48 changes: 42 additions & 6 deletions airflow/providers/google/cloud/hooks/datafusion.py
@@ -31,6 +31,7 @@
from googleapiclient.discovery import Resource, build

from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType
from airflow.providers.google.common.hooks.base_google import (
PROVIDE_PROJECT_ID,
GoogleBaseAsyncHook,
@@ -105,6 +106,7 @@ def wait_for_pipeline_state(
pipeline_name: str,
pipeline_id: str,
instance_url: str,
pipeline_type: DataFusionPipelineType = DataFusionPipelineType.BATCH,
namespace: str = "default",
success_states: list[str] | None = None,
failure_states: list[str] | None = None,
@@ -120,6 +122,7 @@
workflow = self.get_pipeline_workflow(
pipeline_name=pipeline_name,
pipeline_id=pipeline_id,
pipeline_type=pipeline_type,
instance_url=instance_url,
namespace=namespace,
)
@@ -432,13 +435,14 @@ def get_pipeline_workflow(
pipeline_name: str,
instance_url: str,
pipeline_id: str,
pipeline_type: DataFusionPipelineType = DataFusionPipelineType.BATCH,
namespace: str = "default",
) -> Any:
url = os.path.join(
self._base_url(instance_url, namespace),
quote(pipeline_name),
"workflows",
"DataPipelineWorkflow",
f"{self.cdap_program_type(pipeline_type=pipeline_type)}s",
self.cdap_program_id(pipeline_type=pipeline_type),
"runs",
quote(pipeline_id),
)
@@ -453,13 +457,15 @@ def start_pipeline(
self,
pipeline_name: str,
instance_url: str,
pipeline_type: DataFusionPipelineType = DataFusionPipelineType.BATCH,
namespace: str = "default",
runtime_args: dict[str, Any] | None = None,
) -> str:
"""
Starts a Cloud Data Fusion pipeline. Works for both batch and stream pipelines.

:param pipeline_name: Your pipeline name.
:param pipeline_type: Optional pipeline type (BATCH by default).
:param instance_url: Endpoint on which the REST API is accessible for the instance.
:param runtime_args: Optional runtime JSON args to be passed to the pipeline.
:param namespace: if your pipeline belongs to a Basic edition instance, the namespace ID
@@ -480,9 +486,9 @@ def start_pipeline(
body = [
{
"appId": pipeline_name,
"programType": "workflow",
"programId": "DataPipelineWorkflow",
"runtimeargs": runtime_args,
"programType": self.cdap_program_type(pipeline_type=pipeline_type),
"programId": self.cdap_program_id(pipeline_type=pipeline_type),
}
]
response = self._cdap_request(url=url, method="POST", body=body)
@@ -514,6 +520,30 @@ def stop_pipeline(self, pipeline_name: str, instance_url: str, namespace: str =
response, f"Stopping a pipeline failed with code {response.status}"
)

@staticmethod
def cdap_program_type(pipeline_type: DataFusionPipelineType) -> str:
"""Retrieves CDAP Program type depending on the pipeline type.

:param pipeline_type: Pipeline type.
"""
program_types = {
DataFusionPipelineType.BATCH: "workflow",
DataFusionPipelineType.STREAM: "spark",
}
return program_types.get(pipeline_type, "")

@staticmethod
def cdap_program_id(pipeline_type: DataFusionPipelineType) -> str:
"""Retrieves CDAP Program id depending on the pipeline type.

:param pipeline_type: Pipeline type.
"""
program_ids = {
DataFusionPipelineType.BATCH: "DataPipelineWorkflow",
DataFusionPipelineType.STREAM: "DataStreamsSparkStreaming",
}
return program_ids.get(pipeline_type, "")
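Taken together, these two helpers replace the "workflows"/"DataPipelineWorkflow" segments that were previously hardcoded into the run URL. A minimal sketch of the paths get_pipeline_workflow now builds (the base URL, pipeline name, and run ID are placeholders, and it is assumed the sync hook class in this module is DataFusionHook, whose name sits outside this hunk):

    from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook
    from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType

    for ptype in (DataFusionPipelineType.BATCH, DataFusionPipelineType.STREAM):
        program_type = DataFusionHook.cdap_program_type(pipeline_type=ptype)
        program_id = DataFusionHook.cdap_program_id(pipeline_type=ptype)
        # BATCH  -> <base-url>/my_pipeline/workflows/DataPipelineWorkflow/runs/<run-id>
        # STREAM -> <base-url>/my_pipeline/sparks/DataStreamsSparkStreaming/runs/<run-id>
        print(f"<base-url>/my_pipeline/{program_type}s/{program_id}/runs/<run-id>")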


class DataFusionAsyncHook(GoogleBaseAsyncHook):
"""Class to get asynchronous hook for DataFusion."""
@@ -561,10 +591,13 @@ async def get_pipeline(
pipeline_name: str,
pipeline_id: str,
session,
pipeline_type: DataFusionPipelineType = DataFusionPipelineType.BATCH,
):
program_type = self.sync_hook_class.cdap_program_type(pipeline_type=pipeline_type)
program_id = self.sync_hook_class.cdap_program_id(pipeline_type=pipeline_type)
base_url_link = self._base_url(instance_url, namespace)
url = urljoin(
base_url_link, f"{quote(pipeline_name)}/workflows/DataPipelineWorkflow/runs/{quote(pipeline_id)}"
base_url_link, f"{quote(pipeline_name)}/{program_type}s/{program_id}/runs/{quote(pipeline_id)}"
)
return await self._get_link(url=url, session=session)

@@ -573,6 +606,7 @@ async def get_pipeline_status(
pipeline_name: str,
instance_url: str,
pipeline_id: str,
pipeline_type: DataFusionPipelineType = DataFusionPipelineType.BATCH,
namespace: str = "default",
success_states: list[str] | None = None,
) -> str:
@@ -581,7 +615,8 @@

:param pipeline_name: Your pipeline name.
:param instance_url: Endpoint on which the REST API is accessible for the instance.
:param pipeline_id: Unique pipeline ID associated with specific pipeline
:param pipeline_id: Unique pipeline ID associated with specific pipeline.
:param pipeline_type: Optional pipeline type (BATCH by default).
:param namespace: if your pipeline belongs to a Basic edition instance, the namespace ID
is always default. If your pipeline belongs to an Enterprise edition instance, you
can create a namespace.
@@ -596,6 +631,7 @@
namespace=namespace,
pipeline_name=pipeline_name,
pipeline_id=pipeline_id,
pipeline_type=pipeline_type,
session=session,
)
pipeline = await pipeline.json(content_type=None)
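For reference, a sketch of the request body start_pipeline now assembles before POSTing it (the pipeline name and runtime args are placeholders; DataFusionHook as the sync hook class name is an assumption, since it is not visible in this hunk):

    from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook
    from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType

    pipeline_type = DataFusionPipelineType.STREAM
    body = [
        {
            "appId": "my_pipeline",  # placeholder pipeline name
            "runtimeargs": {},       # placeholder runtime args
            # "workflow"/"DataPipelineWorkflow" for BATCH,
            # "spark"/"DataStreamsSparkStreaming" for STREAM
            "programType": DataFusionHook.cdap_program_type(pipeline_type=pipeline_type),
            "programId": DataFusionHook.cdap_program_id(pipeline_type=pipeline_type),
        }
    ]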
7 changes: 7 additions & 0 deletions airflow/providers/google/cloud/operators/datafusion.py
@@ -33,6 +33,7 @@
)
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.datafusion import DataFusionStartPipelineTrigger
from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType

if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -708,6 +709,7 @@ class CloudDataFusionStartPipelineOperator(GoogleCloudBaseOperator):
:ref:`howto/operator:CloudDataFusionStartPipelineOperator`

:param pipeline_name: Your pipeline name.
:param pipeline_type: Optional pipeline type (BATCH by default).
:param instance_name: The name of the instance.
:param success_states: If provided the operator will wait for pipeline to be in one of
the provided states.
@@ -752,6 +754,7 @@ def __init__(
pipeline_name: str,
instance_name: str,
location: str,
pipeline_type: DataFusionPipelineType = DataFusionPipelineType.BATCH,
runtime_args: dict[str, Any] | None = None,
success_states: list[str] | None = None,
namespace: str = "default",
@@ -767,6 +770,7 @@
) -> None:
super().__init__(**kwargs)
self.pipeline_name = pipeline_name
self.pipeline_type = pipeline_type
self.runtime_args = runtime_args
self.namespace = namespace
self.instance_name = instance_name
@@ -800,6 +804,7 @@ def execute(self, context: Context) -> str:
api_url = instance["apiEndpoint"]
pipeline_id = hook.start_pipeline(
pipeline_name=self.pipeline_name,
pipeline_type=self.pipeline_type,
instance_url=api_url,
namespace=self.namespace,
runtime_args=self.runtime_args,
@@ -824,6 +829,7 @@ def execute(self, context: Context) -> str:
instance_url=api_url,
namespace=self.namespace,
pipeline_name=self.pipeline_name,
pipeline_type=self.pipeline_type.value,
pipeline_id=pipeline_id,
poll_interval=self.poll_interval,
gcp_conn_id=self.gcp_conn_id,
@@ -839,6 +845,7 @@ def execute(self, context: Context) -> str:
success_states=self.success_states,
pipeline_id=pipeline_id,
pipeline_name=self.pipeline_name,
pipeline_type=self.pipeline_type,
namespace=self.namespace,
instance_url=api_url,
timeout=self.pipeline_timeout,
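With the new parameter threaded through the operator, starting a streaming pipeline from a DAG becomes a one-argument change. A usage sketch under assumed placeholder names (the instance, pipeline, and location are illustrative):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.datafusion import (
        CloudDataFusionStartPipelineOperator,
    )
    from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType

    with DAG(dag_id="example_datafusion_stream", start_date=datetime(2023, 1, 1), schedule=None):
        start_stream = CloudDataFusionStartPipelineOperator(
            task_id="start_stream_pipeline",
            pipeline_name="my_stream_pipeline",  # placeholder
            pipeline_type=DataFusionPipelineType.STREAM,  # defaults to BATCH if omitted
            instance_name="my-datafusion-instance",  # placeholder
            location="us-central1",  # placeholder
        )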
6 changes: 6 additions & 0 deletions airflow/providers/google/cloud/triggers/datafusion.py
@@ -20,6 +20,7 @@
from typing import Any, AsyncIterator, Sequence

from airflow.providers.google.cloud.hooks.datafusion import DataFusionAsyncHook
from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType
from airflow.triggers.base import BaseTrigger, TriggerEvent


@@ -30,6 +31,7 @@ class DataFusionStartPipelineTrigger(BaseTrigger):
:param pipeline_name: Your pipeline name.
:param instance_url: Endpoint on which the REST API is accessible for the instance.
:param pipeline_id: Unique pipeline ID associated with the specific pipeline.
:param pipeline_type: Your pipeline type.
:param namespace: if your pipeline belongs to a Basic edition instance, the namespace ID
is always default. If your pipeline belongs to an Enterprise edition instance, you
can create a namespace.
@@ -51,6 +53,7 @@ def __init__(
namespace: str,
pipeline_name: str,
pipeline_id: str,
pipeline_type: str,
poll_interval: float = 3.0,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
@@ -61,6 +64,7 @@ def __init__(
self.namespace = namespace
self.pipeline_name = pipeline_name
self.pipeline_id = pipeline_id
self.pipeline_type = pipeline_type
self.poll_interval = poll_interval
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
@@ -76,6 +80,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"namespace": self.namespace,
"pipeline_name": self.pipeline_name,
"pipeline_id": self.pipeline_id,
"pipeline_type": self.pipeline_type,
"success_states": self.success_states,
},
)
@@ -92,6 +97,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
namespace=self.namespace,
pipeline_name=self.pipeline_name,
pipeline_id=self.pipeline_id,
pipeline_type=DataFusionPipelineType.from_str(self.pipeline_type),
)
if response_from_hook == "success":
yield TriggerEvent(
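The string/enum asymmetry here is deliberate: the operator hands the trigger self.pipeline_type.value so that serialize() emits JSON-safe kwargs, and run() rehydrates the enum with DataFusionPipelineType.from_str before calling the hook. A minimal sketch of that round trip:

    from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType

    serialized = DataFusionPipelineType.STREAM.value  # "stream", as stored by serialize()
    rehydrated = DataFusionPipelineType.from_str(serialized)
    assert rehydrated is DataFusionPipelineType.STREAM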
33 changes: 33 additions & 0 deletions airflow/providers/google/cloud/utils/datafusion.py
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from enum import Enum


class DataFusionPipelineType(Enum):
"""Enum for Data Fusion pipeline types."""

BATCH = "batch"
STREAM = "stream"

@staticmethod
def from_str(value: str) -> DataFusionPipelineType:
value_to_item = {item.value: item for item in DataFusionPipelineType}
if value in value_to_item:
return value_to_item[value]
raise ValueError(f"Invalid value '{value}'. Valid values are: {list(value_to_item)}")
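A quick sketch of from_str at its edges, per the mapping above (matching is case-sensitive, and unknown values raise):

    from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType

    DataFusionPipelineType.from_str("batch")   # DataFusionPipelineType.BATCH
    DataFusionPipelineType.from_str("stream")  # DataFusionPipelineType.STREAM
    DataFusionPipelineType.from_str("BATCH")   # ValueError: Invalid value 'BATCH'. Valid values are: ['batch', 'stream']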