Add explicit support of stream (realtime) pipelines for CloudDataFusionStartPipelineOperator (#34271)
moiseenkov committed Sep 11, 2023
1 parent 6042e76 commit 4dcdc34
Showing 8 changed files with 219 additions and 10 deletions.
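In summary, the hook, operator, and trigger gain a pipeline_type argument backed by a new DataFusionPipelineType enum (BATCH by default, STREAM for realtime pipelines). Below is a minimal usage sketch, assuming a DAG context; the pipeline, instance, and location names are placeholders and are not taken from this commit:

from airflow.providers.google.cloud.operators.datafusion import (
    CloudDataFusionStartPipelineOperator,
)
from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType

# Start a stream (realtime) pipeline explicitly; omitting pipeline_type keeps
# the previous batch behaviour.
start_stream_pipeline = CloudDataFusionStartPipelineOperator(
    task_id="start_stream_pipeline",
    pipeline_name="my_realtime_pipeline",          # placeholder
    pipeline_type=DataFusionPipelineType.STREAM,
    instance_name="my-datafusion-instance",        # placeholder
    location="europe-west1",                       # placeholder
    namespace="default",
)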
48 changes: 42 additions & 6 deletions airflow/providers/google/cloud/hooks/datafusion.py
@@ -31,6 +31,7 @@
from googleapiclient.discovery import Resource, build

from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType
from airflow.providers.google.common.hooks.base_google import (
PROVIDE_PROJECT_ID,
GoogleBaseAsyncHook,
@@ -105,6 +106,7 @@ def wait_for_pipeline_state(
pipeline_name: str,
pipeline_id: str,
instance_url: str,
pipeline_type: DataFusionPipelineType = DataFusionPipelineType.BATCH,
namespace: str = "default",
success_states: list[str] | None = None,
failure_states: list[str] | None = None,
@@ -120,6 +122,7 @@
workflow = self.get_pipeline_workflow(
pipeline_name=pipeline_name,
pipeline_id=pipeline_id,
pipeline_type=pipeline_type,
instance_url=instance_url,
namespace=namespace,
)
@@ -432,13 +435,14 @@ def get_pipeline_workflow(
pipeline_name: str,
instance_url: str,
pipeline_id: str,
pipeline_type: DataFusionPipelineType = DataFusionPipelineType.BATCH,
namespace: str = "default",
) -> Any:
url = os.path.join(
self._base_url(instance_url, namespace),
quote(pipeline_name),
"workflows",
"DataPipelineWorkflow",
f"{self.cdap_program_type(pipeline_type=pipeline_type)}s",
self.cdap_program_id(pipeline_type=pipeline_type),
"runs",
quote(pipeline_id),
)
@@ -453,13 +457,15 @@ def start_pipeline(
self,
pipeline_name: str,
instance_url: str,
pipeline_type: DataFusionPipelineType = DataFusionPipelineType.BATCH,
namespace: str = "default",
runtime_args: dict[str, Any] | None = None,
) -> str:
"""
Starts a Cloud Data Fusion pipeline. Works for both batch and stream pipelines.
:param pipeline_name: Your pipeline name.
:param pipeline_type: Optional pipeline type (BATCH by default).
:param instance_url: Endpoint on which the REST APIs are accessible for the instance.
:param runtime_args: Optional runtime JSON args to be passed to the pipeline
:param namespace: if your pipeline belongs to a Basic edition instance, the namespace ID
@@ -480,9 +486,9 @@ def start_pipeline(
body = [
{
"appId": pipeline_name,
"programType": "workflow",
"programId": "DataPipelineWorkflow",
"runtimeargs": runtime_args,
"programType": self.cdap_program_type(pipeline_type=pipeline_type),
"programId": self.cdap_program_id(pipeline_type=pipeline_type),
}
]
response = self._cdap_request(url=url, method="POST", body=body)
@@ -514,6 +520,30 @@ def stop_pipeline(self, pipeline_name: str, instance_url: str, namespace: str =
response, f"Stopping a pipeline failed with code {response.status}"
)

@staticmethod
def cdap_program_type(pipeline_type: DataFusionPipelineType) -> str:
"""Retrieves CDAP Program type depending on the pipeline type.
:param pipeline_type: Pipeline type.
"""
program_types = {
DataFusionPipelineType.BATCH: "workflow",
DataFusionPipelineType.STREAM: "spark",
}
return program_types.get(pipeline_type, "")

@staticmethod
def cdap_program_id(pipeline_type: DataFusionPipelineType) -> str:
"""Retrieves CDAP Program id depending on the pipeline type.
:param pipeline_type: Pipeline type.
"""
program_ids = {
DataFusionPipelineType.BATCH: "DataPipelineWorkflow",
DataFusionPipelineType.STREAM: "DataStreamsSparkStreaming",
}
return program_ids.get(pipeline_type, "")


class DataFusionAsyncHook(GoogleBaseAsyncHook):
"""Class to get asynchronous hook for DataFusion."""
@@ -561,10 +591,13 @@ async def get_pipeline(
pipeline_name: str,
pipeline_id: str,
session,
pipeline_type: DataFusionPipelineType = DataFusionPipelineType.BATCH,
):
program_type = self.sync_hook_class.cdap_program_type(pipeline_type=pipeline_type)
program_id = self.sync_hook_class.cdap_program_id(pipeline_type=pipeline_type)
base_url_link = self._base_url(instance_url, namespace)
url = urljoin(
base_url_link, f"{quote(pipeline_name)}/workflows/DataPipelineWorkflow/runs/{quote(pipeline_id)}"
base_url_link, f"{quote(pipeline_name)}/{program_type}s/{program_id}/runs/{quote(pipeline_id)}"
)
return await self._get_link(url=url, session=session)

@@ -573,6 +606,7 @@ async def get_pipeline_status(
pipeline_name: str,
instance_url: str,
pipeline_id: str,
pipeline_type: DataFusionPipelineType = DataFusionPipelineType.BATCH,
namespace: str = "default",
success_states: list[str] | None = None,
) -> str:
@@ -581,7 +615,8 @@ async def get_pipeline_status(
:param pipeline_name: Your pipeline name.
:param instance_url: Endpoint on which the REST APIs are accessible for the instance.
:param pipeline_id: Unique pipeline ID associated with specific pipeline
:param pipeline_id: Unique pipeline ID associated with specific pipeline.
:param pipeline_type: Optional pipeline type (by default batch).
:param namespace: if your pipeline belongs to a Basic edition instance, the namespace ID
is always default. If your pipeline belongs to an Enterprise edition instance, you
can create a namespace.
@@ -596,6 +631,7 @@
namespace=namespace,
pipeline_name=pipeline_name,
pipeline_id=pipeline_id,
pipeline_type=pipeline_type,
session=session,
)
pipeline = await pipeline.json(content_type=None)
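The hook change boils down to resolving the CDAP program from the pipeline type instead of hard-coding the batch workflow. A small sketch of the mapping added above, using only the static methods from this diff (no Data Fusion connection needed):

from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook
from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType

# Batch pipelines keep the old behaviour: a CDAP "workflow" program, so run URLs
# still end in .../workflows/DataPipelineWorkflow/runs/<pipeline_id>.
assert DataFusionHook.cdap_program_type(DataFusionPipelineType.BATCH) == "workflow"
assert DataFusionHook.cdap_program_id(DataFusionPipelineType.BATCH) == "DataPipelineWorkflow"

# Stream (realtime) pipelines map to a CDAP "spark" program, so get_pipeline_workflow
# queries .../sparks/DataStreamsSparkStreaming/runs/<pipeline_id> instead.
assert DataFusionHook.cdap_program_type(DataFusionPipelineType.STREAM) == "spark"
assert DataFusionHook.cdap_program_id(DataFusionPipelineType.STREAM) == "DataStreamsSparkStreaming"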
7 changes: 7 additions & 0 deletions airflow/providers/google/cloud/operators/datafusion.py
@@ -33,6 +33,7 @@
)
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.datafusion import DataFusionStartPipelineTrigger
from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType

if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -708,6 +709,7 @@ class CloudDataFusionStartPipelineOperator(GoogleCloudBaseOperator):
:ref:`howto/operator:CloudDataFusionStartPipelineOperator`
:param pipeline_name: Your pipeline name.
:param pipeline_type: Optional pipeline type (BATCH by default).
:param instance_name: The name of the instance.
:param success_states: If provided the operator will wait for pipeline to be in one of
the provided states.
@@ -752,6 +754,7 @@ def __init__(
pipeline_name: str,
instance_name: str,
location: str,
pipeline_type: DataFusionPipelineType = DataFusionPipelineType.BATCH,
runtime_args: dict[str, Any] | None = None,
success_states: list[str] | None = None,
namespace: str = "default",
@@ -767,6 +770,7 @@ def __init__(
) -> None:
super().__init__(**kwargs)
self.pipeline_name = pipeline_name
self.pipeline_type = pipeline_type
self.runtime_args = runtime_args
self.namespace = namespace
self.instance_name = instance_name
@@ -800,6 +804,7 @@ def execute(self, context: Context) -> str:
api_url = instance["apiEndpoint"]
pipeline_id = hook.start_pipeline(
pipeline_name=self.pipeline_name,
pipeline_type=self.pipeline_type,
instance_url=api_url,
namespace=self.namespace,
runtime_args=self.runtime_args,
@@ -824,6 +829,7 @@ def execute(self, context: Context) -> str:
instance_url=api_url,
namespace=self.namespace,
pipeline_name=self.pipeline_name,
pipeline_type=self.pipeline_type.value,
pipeline_id=pipeline_id,
poll_interval=self.poll_interval,
gcp_conn_id=self.gcp_conn_id,
@@ -839,6 +845,7 @@ def execute(self, context: Context) -> str:
success_states=self.success_states,
pipeline_id=pipeline_id,
pipeline_name=self.pipeline_name,
pipeline_type=self.pipeline_type,
namespace=self.namespace,
instance_url=api_url,
timeout=self.pipeline_timeout,
6 changes: 6 additions & 0 deletions airflow/providers/google/cloud/triggers/datafusion.py
@@ -20,6 +20,7 @@
from typing import Any, AsyncIterator, Sequence

from airflow.providers.google.cloud.hooks.datafusion import DataFusionAsyncHook
from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType
from airflow.triggers.base import BaseTrigger, TriggerEvent


@@ -30,6 +31,7 @@ class DataFusionStartPipelineTrigger(BaseTrigger):
:param pipeline_name: Your pipeline name.
:param instance_url: Endpoint on which the REST APIs are accessible for the instance.
:param pipeline_id: Unique pipeline ID associated with specific pipeline
:param pipeline_type: Your pipeline type.
:param namespace: if your pipeline belongs to a Basic edition instance, the namespace ID
is always default. If your pipeline belongs to an Enterprise edition instance, you
can create a namespace.
@@ -51,6 +53,7 @@ def __init__(
namespace: str,
pipeline_name: str,
pipeline_id: str,
pipeline_type: str,
poll_interval: float = 3.0,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
@@ -61,6 +64,7 @@ def __init__(
self.namespace = namespace
self.pipeline_name = pipeline_name
self.pipeline_id = pipeline_id
self.pipeline_type = pipeline_type
self.poll_interval = poll_interval
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
@@ -76,6 +80,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"namespace": self.namespace,
"pipeline_name": self.pipeline_name,
"pipeline_id": self.pipeline_id,
"pipeline_type": self.pipeline_type,
"success_states": self.success_states,
},
)
@@ -92,6 +97,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
namespace=self.namespace,
pipeline_name=self.pipeline_name,
pipeline_id=self.pipeline_id,
pipeline_type=DataFusionPipelineType.from_str(self.pipeline_type),
)
if response_from_hook == "success":
yield TriggerEvent(
33 changes: 33 additions & 0 deletions airflow/providers/google/cloud/utils/datafusion.py
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from enum import Enum


class DataFusionPipelineType(Enum):
"""Enum for Data Fusion pipeline types."""

BATCH = "batch"
STREAM = "stream"

@staticmethod
def from_str(value: str) -> DataFusionPipelineType:
value_to_item = {item.value: item for item in DataFusionPipelineType}
if value in value_to_item:
return value_to_item[value]
raise ValueError(f"Invalid value '{value}'. Valid values are: {[i for i in value_to_item.keys()]}")
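The operator passes self.pipeline_type.value (a plain string) to the trigger so the trigger kwargs stay serializable, and the trigger rebuilds the enum with DataFusionPipelineType.from_str inside run(). A quick illustration of that helper, using only what the new utils module above defines:

from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType

assert DataFusionPipelineType.from_str("batch") is DataFusionPipelineType.BATCH
assert DataFusionPipelineType.from_str("stream") is DataFusionPipelineType.STREAM
assert DataFusionPipelineType.STREAM.value == "stream"

# Any other string raises a ValueError listing the accepted values.
try:
    DataFusionPipelineType.from_str("realtime")
except ValueError as err:
    print(err)  # Invalid value 'realtime'. Valid values are: ['batch', 'stream']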
