diff --git a/airflow/providers/google/cloud/links/dataproc.py b/airflow/providers/google/cloud/links/dataproc.py
index d560d2a5ee0cc..16c1493e1e579 100644
--- a/airflow/providers/google/cloud/links/dataproc.py
+++ b/airflow/providers/google/cloud/links/dataproc.py
@@ -18,10 +18,12 @@
 """This module contains Google Dataproc links."""
 from __future__ import annotations
 
+import warnings
 from typing import TYPE_CHECKING
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models import BaseOperatorLink, XCom
-from airflow.providers.google.cloud.links.base import BASE_LINK
+from airflow.providers.google.cloud.links.base import BASE_LINK, BaseGoogleLink
 
 if TYPE_CHECKING:
     from airflow.models import BaseOperator
@@ -29,21 +31,38 @@ from airflow.utils.context import Context
 
 DATAPROC_BASE_LINK = BASE_LINK + "/dataproc"
-DATAPROC_JOB_LOG_LINK = DATAPROC_BASE_LINK + "/jobs/{resource}?region={region}&project={project_id}"
+DATAPROC_JOB_LINK = DATAPROC_BASE_LINK + "/jobs/{job_id}?region={region}&project={project_id}"
+
 DATAPROC_CLUSTER_LINK = (
-    DATAPROC_BASE_LINK + "/clusters/{resource}/monitoring?region={region}&project={project_id}"
+    DATAPROC_BASE_LINK + "/clusters/{cluster_id}/monitoring?region={region}&project={project_id}"
 )
 DATAPROC_WORKFLOW_TEMPLATE_LINK = (
-    DATAPROC_BASE_LINK + "/workflows/templates/{region}/{resource}?project={project_id}"
+    DATAPROC_BASE_LINK + "/workflows/templates/{region}/{workflow_template_id}?project={project_id}"
 )
-DATAPROC_WORKFLOW_LINK = DATAPROC_BASE_LINK + "/workflows/instances/{region}/{resource}?project={project_id}"
-DATAPROC_BATCH_LINK = DATAPROC_BASE_LINK + "/batches/{region}/{resource}/monitoring?project={project_id}"
+DATAPROC_WORKFLOW_LINK = (
+    DATAPROC_BASE_LINK + "/workflows/instances/{region}/{workflow_id}?project={project_id}"
+)
+
+DATAPROC_BATCH_LINK = DATAPROC_BASE_LINK + "/batches/{region}/{batch_id}/monitoring?project={project_id}"
 DATAPROC_BATCHES_LINK = DATAPROC_BASE_LINK + "/batches?project={project_id}"
+DATAPROC_JOB_LINK_DEPRECATED = DATAPROC_BASE_LINK + "/jobs/{resource}?region={region}&project={project_id}"
+DATAPROC_CLUSTER_LINK_DEPRECATED = (
+    DATAPROC_BASE_LINK + "/clusters/{resource}/monitoring?region={region}&project={project_id}"
+)
 
 
 class DataprocLink(BaseOperatorLink):
-    """Helper class for constructing Dataproc resource link."""
+    """
+    Helper class for constructing Dataproc resource link.
+
+    .. warning::
+        This link is deprecated.
+    """
+
+    warnings.warn(
+        "This DataprocLink is deprecated.",
+        AirflowProviderDeprecationWarning,
+    )
 
     name = "Dataproc resource"
     key = "conf"
@@ -82,8 +101,14 @@ def get_link(
 
 
 class DataprocListLink(BaseOperatorLink):
-    """Helper class for constructing list of Dataproc resources link."""
+    """
+    Helper class for constructing list of Dataproc resources link.
+
+    .. warning::
+        This link is deprecated.
+    """
 
+    warnings.warn("This DataprocListLink is deprecated.", AirflowProviderDeprecationWarning)
     name = "Dataproc resources"
     key = "list_conf"
@@ -116,3 +141,127 @@ def get_link(
             if list_conf
             else ""
         )
+
+
+class DataprocClusterLink(BaseGoogleLink):
+    """Helper class for constructing Dataproc Cluster Link."""
+
+    name = "Dataproc Cluster"
+    key = "dataproc_cluster"
+    format_str = DATAPROC_CLUSTER_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        operator: BaseOperator,
+        cluster_id: str,
+        region: str,
+        project_id: str,
+    ):
+        operator.xcom_push(
+            context,
+            key=DataprocClusterLink.key,
+            value={"cluster_id": cluster_id, "region": region, "project_id": project_id},
+        )
+
+
+class DataprocJobLink(BaseGoogleLink):
+    """Helper class for constructing Dataproc Job Link."""
+
+    name = "Dataproc Job"
+    key = "dataproc_job"
+    format_str = DATAPROC_JOB_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        operator: BaseOperator,
+        job_id: str,
+        region: str,
+        project_id: str,
+    ):
+        operator.xcom_push(
+            context,
+            key=DataprocJobLink.key,
+            value={"job_id": job_id, "region": region, "project_id": project_id},
+        )
+
+
+class DataprocWorkflowLink(BaseGoogleLink):
+    """Helper class for constructing Dataproc Workflow Link."""
+
+    name = "Dataproc Workflow"
+    key = "dataproc_workflow"
+    format_str = DATAPROC_WORKFLOW_LINK
+
+    @staticmethod
+    def persist(context: Context, operator: BaseOperator, workflow_id: str, project_id: str, region: str):
+        operator.xcom_push(
+            context,
+            key=DataprocWorkflowLink.key,
+            value={"workflow_id": workflow_id, "region": region, "project_id": project_id},
+        )
+
+
+class DataprocWorkflowTemplateLink(BaseGoogleLink):
+    """Helper class for constructing Dataproc Workflow Template Link."""
+
+    name = "Dataproc Workflow Template"
+    key = "dataproc_workflow_template"
+    format_str = DATAPROC_WORKFLOW_TEMPLATE_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        operator: BaseOperator,
+        workflow_template_id: str,
+        project_id: str,
+        region: str,
+    ):
+        operator.xcom_push(
+            context,
+            key=DataprocWorkflowTemplateLink.key,
+            value={"workflow_template_id": workflow_template_id, "region": region, "project_id": project_id},
+        )
+
+
+class DataprocBatchLink(BaseGoogleLink):
+    """Helper class for constructing Dataproc Batch Link."""
+
+    name = "Dataproc Batch"
+    key = "dataproc_batch"
+    format_str = DATAPROC_BATCH_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        operator: BaseOperator,
+        batch_id: str,
+        project_id: str,
+        region: str,
+    ):
+        operator.xcom_push(
+            context,
+            key=DataprocBatchLink.key,
+            value={"batch_id": batch_id, "region": region, "project_id": project_id},
+        )
+
+
+class DataprocBatchesListLink(BaseGoogleLink):
+    """Helper class for constructing Dataproc Batches List Link."""
+
+    name = "Dataproc Batches List"
+    key = "dataproc_batches_list"
+    format_str = DATAPROC_BATCHES_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        operator: BaseOperator,
+        project_id: str,
+    ):
+        operator.xcom_push(
+            context,
+            key=DataprocBatchesListLink.key,
+            value={"project_id": project_id},
+        )
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 423ebf988f935..09e76ae2069fe 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -43,13 +43,15 @@
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
 from airflow.providers.google.cloud.links.dataproc import (
     DATAPROC_BATCH_LINK,
-    DATAPROC_BATCHES_LINK,
-    DATAPROC_CLUSTER_LINK,
-    DATAPROC_JOB_LOG_LINK,
-    DATAPROC_WORKFLOW_LINK,
-    DATAPROC_WORKFLOW_TEMPLATE_LINK,
+    DATAPROC_CLUSTER_LINK_DEPRECATED,
+    DATAPROC_JOB_LINK_DEPRECATED,
+    DataprocBatchesListLink,
+    DataprocBatchLink,
+    DataprocClusterLink,
+    DataprocJobLink,
     DataprocLink,
-    DataprocListLink,
+    DataprocWorkflowLink,
+    DataprocWorkflowTemplateLink,
 )
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
 from airflow.providers.google.cloud.triggers.dataproc import (
@@ -189,6 +191,7 @@ def __init__(
         enable_component_gateway: bool | None = False,
         **kwargs,
     ) -> None:
+        self.project_id = project_id
         self.num_masters = num_masters
         self.num_workers = num_workers
@@ -488,7 +491,7 @@ class DataprocCreateClusterOperator(GoogleCloudBaseOperator):
     )
     template_fields_renderers = {"cluster_config": "json", "virtual_cluster_config": "json"}
 
-    operator_extra_links = (DataprocLink(),)
+    operator_extra_links = (DataprocClusterLink(),)
 
     def __init__(
         self,
@@ -629,9 +632,15 @@ def execute(self, context: Context) -> dict:
         self.log.info("Creating cluster: %s", self.cluster_name)
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         # Save data required to display extra link no matter what the cluster status will be
-        DataprocLink.persist(
-            context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
-        )
+        project_id = self.project_id or hook.project_id
+        if project_id:
+            DataprocClusterLink.persist(
+                context=context,
+                operator=self,
+                cluster_id=self.cluster_name,
+                project_id=project_id,
+                region=self.region,
+            )
         try:
             # First try to create a new cluster
             operation = self._create_cluster(hook)
@@ -814,7 +823,10 @@ def execute(self, context: Context) -> None:
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         # Save data required to display extra link no matter what the cluster status will be
         DataprocLink.persist(
-            context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
+            context=context,
+            task_instance=self,
+            url=DATAPROC_CLUSTER_LINK_DEPRECATED,
+            resource=self.cluster_name,
         )
         operation = hook.update_cluster(
             project_id=self.project_id,
@@ -1070,7 +1082,7 @@ def execute(self, context: Context):
         self.log.info("Job %s submitted successfully.", job_id)
         # Save data required for extra links no matter what the job status will be
         DataprocLink.persist(
-            context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=job_id
+            context=context, task_instance=self, url=DATAPROC_JOB_LINK_DEPRECATED, resource=job_id
         )
 
         if self.deferrable:
@@ -1669,7 +1681,7 @@ class DataprocCreateWorkflowTemplateOperator(GoogleCloudBaseOperator):
 
     template_fields: Sequence[str] = ("region", "template")
     template_fields_renderers = {"template": "json"}
-    operator_extra_links = (DataprocLink(),)
+    operator_extra_links = (DataprocWorkflowTemplateLink(),)
 
     def __init__(
         self,
@@ -1709,12 +1721,15 @@ def execute(self, context: Context):
             self.log.info("Workflow %s created", workflow.name)
         except AlreadyExists:
             self.log.info("Workflow with given id already exists")
-        DataprocLink.persist(
-            context=context,
-            task_instance=self,
-            url=DATAPROC_WORKFLOW_TEMPLATE_LINK,
-            resource=self.template["id"],
-        )
+        project_id = self.project_id or hook.project_id
+        if project_id:
+            DataprocWorkflowTemplateLink.persist(
+                context=context,
+                operator=self,
+                workflow_template_id=self.template["id"],
+                region=self.region,
+                project_id=project_id,
+            )
 
 
 class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
@@ -1759,7 +1774,7 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
 
     template_fields: Sequence[str] = ("template_id", "impersonation_chain", "request_id", "parameters")
     template_fields_renderers = {"parameters": "json"}
-    operator_extra_links = (DataprocLink(),)
+    operator_extra_links = (DataprocWorkflowLink(),)
 
     def __init__(
         self,
@@ -1811,9 +1826,15 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         self.workflow_id = operation.operation.name.split("/")[-1]
-        DataprocLink.persist(
-            context=context, task_instance=self, url=DATAPROC_WORKFLOW_LINK, resource=self.workflow_id
-        )
+        project_id = self.project_id or hook.project_id
+        if project_id:
+            DataprocWorkflowLink.persist(
+                context=context,
+                operator=self,
+                workflow_id=self.workflow_id,
+                region=self.region,
+                project_id=project_id,
+            )
         self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
         if not self.deferrable:
             hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
@@ -1889,7 +1910,7 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
 
     template_fields: Sequence[str] = ("template", "impersonation_chain")
     template_fields_renderers = {"template": "json"}
-    operator_extra_links = (DataprocLink(),)
+    operator_extra_links = (DataprocWorkflowLink(),)
 
     def __init__(
         self,
@@ -1926,9 +1947,10 @@ def __init__(
     def execute(self, context: Context):
         self.log.info("Instantiating Inline Template")
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        project_id = self.project_id or hook.project_id
         operation = hook.instantiate_inline_workflow_template(
             template=self.template,
-            project_id=self.project_id or hook.project_id,
+            project_id=project_id,
             region=self.region,
             request_id=self.request_id,
             retry=self.retry,
@@ -1936,9 +1958,14 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         self.workflow_id = operation.operation.name.split("/")[-1]
-        DataprocLink.persist(
-            context=context, task_instance=self, url=DATAPROC_WORKFLOW_LINK, resource=self.workflow_id
-        )
+        if project_id:
+            DataprocWorkflowLink.persist(
+                context=context,
+                operator=self,
+                workflow_id=self.workflow_id,
+                region=self.region,
+                project_id=project_id,
+            )
         if not self.deferrable:
             self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
             operation.result()
@@ -2010,7 +2037,7 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
 
     template_fields: Sequence[str] = ("project_id", "region", "job", "impersonation_chain", "request_id")
     template_fields_renderers = {"job": "json"}
-    operator_extra_links = (DataprocLink(),)
+    operator_extra_links = (DataprocJobLink(),)
 
     def __init__(
         self,
@@ -2066,9 +2093,15 @@ def execute(self, context: Context):
         new_job_id: str = job_object.reference.job_id
         self.log.info("Job %s submitted successfully.", new_job_id)
         # Save data required by extra links no matter what the job status will be
-        DataprocLink.persist(
-            context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=new_job_id
-        )
+        project_id = self.project_id or self.hook.project_id
+        if project_id:
+            DataprocJobLink.persist(
+                context=context,
+                operator=self,
+                job_id=new_job_id,
+                region=self.region,
+                project_id=project_id,
+            )
         self.job_id = new_job_id
 
         if self.deferrable:
@@ -2168,7 +2201,7 @@ class DataprocUpdateClusterOperator(GoogleCloudBaseOperator):
         "project_id",
         "impersonation_chain",
     )
-    operator_extra_links = (DataprocLink(),)
+    operator_extra_links = (DataprocClusterLink(),)
 
     def __init__(
         self,
@@ -2210,9 +2243,15 @@ def __init__(
     def execute(self, context: Context):
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         # Save data required by extra links no matter what the cluster status will be
-        DataprocLink.persist(
-            context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
-        )
+        project_id = self.project_id or hook.project_id
+        if project_id:
+            DataprocClusterLink.persist(
+                context=context,
+                operator=self,
+                cluster_id=self.cluster_name,
+                project_id=project_id,
+                region=self.region,
+            )
         self.log.info("Updating %s cluster.", self.cluster_name)
         operation = hook.update_cluster(
             project_id=self.project_id,
@@ -2299,7 +2338,7 @@ class DataprocCreateBatchOperator(GoogleCloudBaseOperator):
         "region",
         "impersonation_chain",
     )
-    operator_extra_links = (DataprocLink(),)
+    operator_extra_links = (DataprocBatchLink(),)
 
     def __init__(
         self,
@@ -2344,7 +2383,7 @@ def execute(self, context: Context):
         # batch_id might not be set and will be generated
         if self.batch_id:
             link = DATAPROC_BATCH_LINK.format(
-                region=self.region, project_id=self.project_id, resource=self.batch_id
+                region=self.region, project_id=self.project_id, batch_id=self.batch_id
             )
             self.log.info("Creating batch %s", self.batch_id)
             self.log.info("Once started, the batch job will be available at %s", link)
@@ -2423,7 +2462,17 @@ def execute(self, context: Context):
                 wait_check_interval=self.polling_interval_seconds,
             )
         batch_id = self.batch_id or result.name.split("/")[-1]
+
         self.handle_batch_status(context, result.state, batch_id)
+        project_id = self.project_id or hook.project_id
+        if project_id:
+            DataprocBatchLink.persist(
+                context=context,
+                operator=self,
+                project_id=project_id,
+                region=self.region,
+                batch_id=batch_id,
+            )
         return Batch.to_dict(result)
 
     def execute_complete(self, context, event=None) -> None:
@@ -2446,24 +2495,14 @@ def handle_batch_status(self, context: Context, state: Batch.State, batch_id: st
         # The existing batch may be a number of states other than 'SUCCEEDED'\
         # wait_for_operation doesn't fail if the job is cancelled, so we will check for it here which also
         # finds a cancelling|canceled|unspecified job from wait_for_batch or the deferred trigger
-        link = DATAPROC_BATCH_LINK.format(region=self.region, project_id=self.project_id, resource=batch_id)
+        link = DATAPROC_BATCH_LINK.format(region=self.region, project_id=self.project_id, batch_id=batch_id)
         if state == Batch.State.FAILED:
-            DataprocLink.persist(
-                context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id
-            )
             raise AirflowException("Batch job %s failed. Driver Logs: %s", batch_id, link)
         if state in (Batch.State.CANCELLED, Batch.State.CANCELLING):
-            DataprocLink.persist(
-                context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id
-            )
             raise AirflowException("Batch job %s was cancelled. Driver logs: %s", batch_id, link)
         if state == Batch.State.STATE_UNSPECIFIED:
-            DataprocLink.persist(
-                context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id
-            )
             raise AirflowException("Batch job %s unspecified. Driver logs: %s", batch_id, link)
         self.log.info("Batch job %s completed. Driver logs: %s", batch_id, link)
-        DataprocLink.persist(context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id)
 
 
 class DataprocDeleteBatchOperator(GoogleCloudBaseOperator):
@@ -2554,7 +2593,7 @@ class DataprocGetBatchOperator(GoogleCloudBaseOperator):
     """
 
     template_fields: Sequence[str] = ("batch_id", "region", "project_id", "impersonation_chain")
-    operator_extra_links = (DataprocLink(),)
+    operator_extra_links = (DataprocBatchLink(),)
 
     def __init__(
         self,
@@ -2590,9 +2629,15 @@ def execute(self, context: Context):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        DataprocLink.persist(
-            context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=self.batch_id
-        )
+        project_id = self.project_id or hook.project_id
+        if project_id:
+            DataprocBatchLink.persist(
+                context=context,
+                operator=self,
+                project_id=project_id,
+                region=self.region,
+                batch_id=self.batch_id,
+            )
         return Batch.to_dict(batch)
 
 
@@ -2624,7 +2669,7 @@ class DataprocListBatchesOperator(GoogleCloudBaseOperator):
     """
 
     template_fields: Sequence[str] = ("region", "project_id", "impersonation_chain")
-    operator_extra_links = (DataprocListLink(),)
+    operator_extra_links = (DataprocBatchesListLink(),)
 
     def __init__(
         self,
@@ -2668,7 +2713,9 @@ def execute(self, context: Context):
             filter=self.filter,
             order_by=self.order_by,
         )
-        DataprocListLink.persist(context=context, task_instance=self, url=DATAPROC_BATCHES_LINK)
+        project_id = self.project_id or hook.project_id
+        if project_id:
+            DataprocBatchesListLink.persist(context=context, operator=self, project_id=project_id)
         return [Batch.to_dict(result) for result in results]
 
 
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 1b96ca972a616..f43b088720be0 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -1063,6 +1063,12 @@ extra-links:
   - airflow.providers.google.cloud.links.datacatalog.DataCatalogTagTemplateLink
   - airflow.providers.google.cloud.links.dataproc.DataprocLink
   - airflow.providers.google.cloud.links.dataproc.DataprocListLink
+  - airflow.providers.google.cloud.links.dataproc.DataprocClusterLink
+  - airflow.providers.google.cloud.links.dataproc.DataprocJobLink
+  - airflow.providers.google.cloud.links.dataproc.DataprocWorkflowLink
+  - airflow.providers.google.cloud.links.dataproc.DataprocWorkflowTemplateLink
+  - airflow.providers.google.cloud.links.dataproc.DataprocBatchLink
+  - airflow.providers.google.cloud.links.dataproc.DataprocBatchesListLink
   - airflow.providers.google.cloud.operators.dataproc_metastore.DataprocMetastoreDetailedLink
   - airflow.providers.google.cloud.operators.dataproc_metastore.DataprocMetastoreLink
   - airflow.providers.google.cloud.links.dataprep.DataprepFlowLink
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 38e4ffeef5fdf..6ddaec8be8928 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -32,10 +32,14 @@
     TaskDeferred,
 )
 from airflow.models import DAG, DagBag
+from airflow.providers.google.cloud.links.dataproc import (
+    DATAPROC_CLUSTER_LINK_DEPRECATED,
+    DATAPROC_JOB_LINK_DEPRECATED,
+    DataprocClusterLink,
+    DataprocJobLink,
+    DataprocWorkflowLink,
+)
 from airflow.providers.google.cloud.operators.dataproc import (
-    DATAPROC_CLUSTER_LINK,
-    DATAPROC_JOB_LOG_LINK,
-    DATAPROC_WORKFLOW_LINK,
     ClusterGenerator,
     DataprocCreateBatchOperator,
     DataprocCreateClusterOperator,
@@ -241,19 +245,28 @@
     "resource": TEST_JOB_ID,
     "region": GCP_REGION,
     "project_id": GCP_PROJECT,
-    "url": DATAPROC_JOB_LOG_LINK,
+    "url": DATAPROC_JOB_LINK_DEPRECATED,
+}
+DATAPROC_JOB_EXPECTED = {
+    "job_id": TEST_JOB_ID,
+    "region": GCP_REGION,
+    "project_id": GCP_PROJECT,
 }
 DATAPROC_CLUSTER_CONF_EXPECTED = {
     "resource": CLUSTER_NAME,
     "region": GCP_REGION,
     "project_id": GCP_PROJECT,
-    "url": DATAPROC_CLUSTER_LINK,
+    "url": DATAPROC_CLUSTER_LINK_DEPRECATED,
 }
-DATAPROC_WORKFLOW_CONF_EXPECTED = {
-    "resource": TEST_WORKFLOW_ID,
+DATAPROC_CLUSTER_EXPECTED = {
+    "cluster_id": CLUSTER_NAME,
+    "region": GCP_REGION,
+    "project_id": GCP_PROJECT,
+}
+DATAPROC_WORKFLOW_EXPECTED = {
+    "workflow_id": TEST_WORKFLOW_ID,
     "region": GCP_REGION,
     "project_id": GCP_PROJECT,
-    "url": DATAPROC_WORKFLOW_LINK,
 }
 BATCH_ID = "test-batch-id"
 BATCH = {
@@ -306,7 +319,7 @@ class DataprocClusterTestBase(DataprocTestBase):
     def setup_class(cls):
         super().setup_class()
         cls.extra_links_expected_calls_base = [
-            call.ti.xcom_push(execution_date=None, key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+            call.ti.xcom_push(execution_date=None, key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED)
         ]
 
 
@@ -488,8 +501,8 @@ def test_execute(self, mock_hook, to_dict_mock):
         to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation())
 
         self.mock_ti.xcom_push.assert_called_once_with(
-            key="conf",
-            value=DATAPROC_CLUSTER_CONF_EXPECTED,
+            key="dataproc_cluster",
+            value=DATAPROC_CLUSTER_EXPECTED,
             execution_date=None,
         )
 
@@ -537,8 +550,8 @@ def test_execute_in_gke(self, mock_hook, to_dict_mock):
         to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation())
 
         self.mock_ti.xcom_push.assert_called_once_with(
-            key="conf",
-            value=DATAPROC_CLUSTER_CONF_EXPECTED,
+            key="dataproc_cluster",
+            value=DATAPROC_CLUSTER_EXPECTED,
             execution_date=None,
         )
 
@@ -742,28 +755,35 @@ def test_create_cluster_operator_extra_links(dag_maker, create_task_instance_of_
 
     # Assert operator links for serialized DAG
     assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
-        {"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
+        {"airflow.providers.google.cloud.links.dataproc.DataprocClusterLink": {}}
     ]
 
     # Assert operator link types are preserved during deserialization
-    assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink)
+    assert isinstance(deserialized_task.operator_extra_links[0], DataprocClusterLink)
 
     # Assert operator link is empty when no XCom push occurred
-    assert ti.task.get_extra_links(ti, DataprocLink.name) == ""
+    assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == ""
 
     # Assert operator link is empty for deserialized task when no XCom push occurred
-    assert deserialized_task.get_extra_links(ti, DataprocLink.name) == ""
+    assert deserialized_task.get_extra_links(ti, DataprocClusterLink.name) == ""
 
-    ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+    ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED)
 
     # Assert operator links are preserved in deserialized tasks after execution
-    assert deserialized_task.get_extra_links(ti, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
+    assert deserialized_task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
 
     # Assert operator links after execution
-    assert ti.task.get_extra_links(ti, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
+    assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
 
 
 class TestDataprocClusterScaleOperator(DataprocClusterTestBase):
+    @classmethod
+    def setup_class(cls):
+        super().setup_class()
+        cls.extra_links_expected_calls_base = [
+            call.ti.xcom_push(execution_date=None, key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+        ]
+
     def test_deprecation_warning(self):
         with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
             DataprocScaleClusterOperator(task_id=TASK_ID, cluster_name=CLUSTER_NAME, project_id=GCP_PROJECT)
@@ -847,7 +867,10 @@ def test_scale_cluster_operator_extra_links(dag_maker, create_task_instance_of_o
     # Assert operator link is empty for deserialized task when no XCom push occurred
     assert deserialized_task.get_extra_links(ti, DataprocLink.name) == ""
 
-    ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+    ti.xcom_push(
+        key="conf",
+        value=DATAPROC_CLUSTER_CONF_EXPECTED,
+    )
 
     # Assert operator links are preserved in deserialized tasks after execution
     assert deserialized_task.get_extra_links(ti, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
@@ -929,7 +952,9 @@ def test_create_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
 class TestDataprocSubmitJobOperator(DataprocJobTestBase):
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute(self, mock_hook):
-        xcom_push_call = call.ti.xcom_push(execution_date=None, key="conf", value=DATAPROC_JOB_CONF_EXPECTED)
+        xcom_push_call = call.ti.xcom_push(
+            execution_date=None, key="dataproc_job", value=DATAPROC_JOB_EXPECTED
+        )
         wait_for_job_call = call.hook().wait_for_job(
             job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT, timeout=None
         )
@@ -976,7 +1001,7 @@ def test_execute(self, mock_hook):
         )
 
         self.mock_ti.xcom_push.assert_called_once_with(
-            key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
+            key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None
         )
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -1016,7 +1041,7 @@ def test_execute_async(self, mock_hook):
         mock_hook.return_value.wait_for_job.assert_not_called()
 
         self.mock_ti.xcom_push.assert_called_once_with(
-            key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
+            key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None
         )
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -1185,25 +1210,25 @@ def test_submit_job_operator_extra_links(mock_hook, dag_maker, create_task_insta
 
     # Assert operator links for serialized_dag
     assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
-        {"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
+        {"airflow.providers.google.cloud.links.dataproc.DataprocJobLink": {}}
    ]
 
     # Assert operator link types are preserved during deserialization
-    assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink)
+    assert isinstance(deserialized_task.operator_extra_links[0], DataprocJobLink)
 
     # Assert operator link is empty when no XCom push occurred
-    assert ti.task.get_extra_links(ti, DataprocLink.name) == ""
+    assert ti.task.get_extra_links(ti, DataprocJobLink.name) == ""
 
     # Assert operator link is empty for deserialized task when no XCom push occurred
-    assert deserialized_task.get_extra_links(ti, DataprocLink.name) == ""
+    assert deserialized_task.get_extra_links(ti, DataprocJobLink.name) == ""
 
-    ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED)
+    ti.xcom_push(key="dataproc_job", value=DATAPROC_JOB_EXPECTED)
 
     # Assert operator links are preserved in deserialized tasks
-    assert deserialized_task.get_extra_links(ti, DataprocLink.name) == DATAPROC_JOB_LINK_EXPECTED
+    assert deserialized_task.get_extra_links(ti, DataprocJobLink.name) == DATAPROC_JOB_LINK_EXPECTED
 
     # Assert operator links after execution
-    assert ti.task.get_extra_links(ti, DataprocLink.name) == DATAPROC_JOB_LINK_EXPECTED
+    assert ti.task.get_extra_links(ti, DataprocJobLink.name) == DATAPROC_JOB_LINK_EXPECTED
 
 
 class TestDataprocUpdateClusterOperator(DataprocClusterTestBase):
@@ -1251,8 +1276,8 @@ def test_execute(self, mock_hook):
         self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)
 
         self.mock_ti.xcom_push.assert_called_once_with(
-            key="conf",
-            value=DATAPROC_CLUSTER_CONF_EXPECTED,
+            key="dataproc_cluster",
+            value=DATAPROC_CLUSTER_EXPECTED,
             execution_date=None,
         )
 
@@ -1342,25 +1367,25 @@ def test_update_cluster_operator_extra_links(dag_maker, create_task_instance_of_
 
     # Assert operator links for serialized_dag
     assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
-        {"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
+        {"airflow.providers.google.cloud.links.dataproc.DataprocClusterLink": {}}
     ]
 
     # Assert operator link types are preserved during deserialization
-    assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink)
+    assert isinstance(deserialized_task.operator_extra_links[0], DataprocClusterLink)
 
     # Assert operator link is empty when no XCom push occurred
-    assert ti.task.get_extra_links(ti, DataprocLink.name) == ""
+    assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == ""
 
     # Assert operator link is empty for deserialized task when no XCom push occurred
-    assert deserialized_task.get_extra_links(ti, DataprocLink.name) == ""
+    assert deserialized_task.get_extra_links(ti, DataprocClusterLink.name) == ""
 
-    ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+    ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED)
 
     # Assert operator links are preserved in deserialized tasks
-    assert deserialized_task.get_extra_links(ti, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
+    assert deserialized_task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
 
     # Assert operator links after execution
-    assert ti.task.get_extra_links(ti, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
+    assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
 
 
 class TestDataprocWorkflowTemplateInstantiateOperator:
@@ -1448,25 +1473,25 @@ def test_instantiate_workflow_operator_extra_links(mock_hook, dag_maker, create_
 
     # Assert operator links for serialized_dag
     assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
-        {"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
+        {"airflow.providers.google.cloud.links.dataproc.DataprocWorkflowLink": {}}
     ]
 
     # Assert operator link types are preserved during deserialization
-    assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink)
+    assert isinstance(deserialized_task.operator_extra_links[0], DataprocWorkflowLink)
 
     # Assert operator link is empty when no XCom push occurred
-    assert ti.task.get_extra_links(ti, DataprocLink.name) == ""
+    assert ti.task.get_extra_links(ti, DataprocWorkflowLink.name) == ""
 
     # Assert operator link is empty for deserialized task when no XCom push occurred
-    assert deserialized_task.get_extra_links(ti, DataprocLink.name) == ""
+    assert deserialized_task.get_extra_links(ti, DataprocWorkflowLink.name) == ""
 
-    ti.xcom_push(key="conf", value=DATAPROC_WORKFLOW_CONF_EXPECTED)
+    ti.xcom_push(key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED)
 
     # Assert operator links are preserved in deserialized tasks
-    assert deserialized_task.get_extra_links(ti, DataprocLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
+    assert deserialized_task.get_extra_links(ti, DataprocWorkflowLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
 
     # Assert operator links after execution
-    assert ti.task.get_extra_links(ti, DataprocLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
+    assert ti.task.get_extra_links(ti, DataprocWorkflowLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
 
 
 class TestDataprocWorkflowTemplateInstantiateInlineOperator:
@@ -1548,25 +1573,25 @@ def test_instantiate_inline_workflow_operator_extra_links(
 
     # Assert operator links for serialized_dag
     assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
-        {"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
+        {"airflow.providers.google.cloud.links.dataproc.DataprocWorkflowLink": {}}
     ]
 
     # Assert operator link types are preserved during deserialization
-    assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink)
+    assert isinstance(deserialized_task.operator_extra_links[0], DataprocWorkflowLink)
 
     # Assert operator link is empty when no XCom push occurred
-    assert ti.task.get_extra_links(ti, DataprocLink.name) == ""
+    assert ti.task.get_extra_links(ti, DataprocWorkflowLink.name) == ""
 
     # Assert operator link is empty for deserialized task when no XCom push occurred
-    assert deserialized_task.get_extra_links(ti, DataprocLink.name) == ""
+    assert deserialized_task.get_extra_links(ti, DataprocWorkflowLink.name) == ""
 
-    ti.xcom_push(key="conf", value=DATAPROC_WORKFLOW_CONF_EXPECTED)
+    ti.xcom_push(key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED)
 
     # Assert operator links are preserved in deserialized tasks
-    assert deserialized_task.get_extra_links(ti, DataprocLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
+    assert deserialized_task.get_extra_links(ti, DataprocWorkflowLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
 
     # Assert operator links after execution
-    assert ti.task.get_extra_links(ti, DataprocLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
+    assert ti.task.get_extra_links(ti, DataprocWorkflowLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
 
 
 class TestDataProcHiveOperator:
@@ -1789,6 +1814,13 @@ def test_builder(self, mock_hook, mock_uuid):
 
 
 class TestDataProcSparkOperator(DataprocJobTestBase):
+    @classmethod
+    def setup_class(cls):
+        cls.extra_links_expected_calls = [
+            call.ti.xcom_push(execution_date=None, key="conf", value=DATAPROC_JOB_CONF_EXPECTED),
+            call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT),
+        ]
+
     main_class = "org.apache.spark.examples.SparkPi"
     jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"]
     job_name = "simple"