Add deferrable AzureDataFactoryRunPipelineOperator (#30147)
* Add deferrable mode to AzureDataFactoryRunPipelineOperator

* Add deferrable mode to AzureDataFactoryRunPipelineOperator

* Fix docs
phanikumv committed Mar 17, 2023
1 parent e09d00e commit c99201a
Showing 7 changed files with 460 additions and 15 deletions.
3 changes: 2 additions & 1 deletion airflow/providers/microsoft/azure/hooks/data_factory.py
@@ -117,8 +117,9 @@ class AzureDataFactoryPipelineRunStatus:
FAILED = "Failed"
CANCELING = "Canceling"
CANCELLED = "Cancelled"

TERMINAL_STATUSES = {CANCELLED, FAILED, SUCCEEDED}
INTERMEDIATE_STATES = {QUEUED, IN_PROGRESS, CANCELING}
FAILURE_STATES = {FAILED, CANCELLED}


class AzureDataFactoryPipelineRunException(AirflowException):
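For context, here is a minimal sketch (not part of this commit) of how the new state sets can drive a polling decision; the classify helper below is hypothetical:

# Hypothetical helper illustrating how the new state sets partition run statuses.
from airflow.providers.microsoft.azure.hooks.data_factory import AzureDataFactoryPipelineRunStatus


def classify(status: str) -> str:
    """Map a raw ADF run status onto a polling decision."""
    if status in AzureDataFactoryPipelineRunStatus.FAILURE_STATES:
        return "error"
    if status == AzureDataFactoryPipelineRunStatus.SUCCEEDED:
        return "success"
    if status in AzureDataFactoryPipelineRunStatus.INTERMEDIATE_STATES:
        return "keep-polling"
    return "unknown"


assert classify("Cancelled") == "error"
assert classify("InProgress") == "keep-polling"

This mirrors the branch structure the new trigger uses when it polls a pipeline run.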
66 changes: 53 additions & 13 deletions airflow/providers/microsoft/azure/operators/data_factory.py
@@ -16,8 +16,11 @@
# under the License.
from __future__ import annotations

import time
import warnings
from typing import TYPE_CHECKING, Any, Sequence

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator, BaseOperatorLink, XCom
from airflow.providers.microsoft.azure.hooks.data_factory import (
@@ -26,6 +29,7 @@
AzureDataFactoryPipelineRunStatus,
get_field,
)
from airflow.providers.microsoft.azure.triggers.data_factory import AzureDataFactoryTrigger
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
@@ -102,6 +106,7 @@ class AzureDataFactoryRunPipelineOperator(BaseOperator):
waits. Used only if ``wait_for_termination`` is True.
:param check_interval: Time in seconds to check on a pipeline run's status for non-asynchronous waits.
Used only if ``wait_for_termination`` is True.
:param deferrable: Run operator in deferrable mode.
"""

template_fields: Sequence[str] = (
@@ -133,6 +138,7 @@ def __init__(
parameters: dict[str, Any] | None = None,
timeout: int = 60 * 60 * 24 * 7,
check_interval: int = 60,
deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -148,6 +154,7 @@ def __init__(
self.parameters = parameters
self.timeout = timeout
self.check_interval = check_interval
self.deferrable = deferrable

def execute(self, context: Context) -> None:
self.hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
@@ -169,21 +176,54 @@ def execute(self, context: Context) -> None:
context["ti"].xcom_push(key="run_id", value=self.run_id)

if self.wait_for_termination:
self.log.info("Waiting for pipeline run %s to terminate.", self.run_id)

if self.hook.wait_for_pipeline_run_status(
run_id=self.run_id,
expected_statuses=AzureDataFactoryPipelineRunStatus.SUCCEEDED,
check_interval=self.check_interval,
timeout=self.timeout,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
):
self.log.info("Pipeline run %s has completed successfully.", self.run_id)
if self.deferrable is False:
self.log.info("Waiting for pipeline run %s to terminate.", self.run_id)

if self.hook.wait_for_pipeline_run_status(
run_id=self.run_id,
expected_statuses=AzureDataFactoryPipelineRunStatus.SUCCEEDED,
check_interval=self.check_interval,
timeout=self.timeout,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
):
self.log.info("Pipeline run %s has completed successfully.", self.run_id)
else:
raise AzureDataFactoryPipelineRunException(
f"Pipeline run {self.run_id} has failed or has been cancelled."
)
else:
raise AzureDataFactoryPipelineRunException(
f"Pipeline run {self.run_id} has failed or has been cancelled."
end_time = time.time() + self.timeout
self.defer(
timeout=self.execution_timeout,
trigger=AzureDataFactoryTrigger(
azure_data_factory_conn_id=self.azure_data_factory_conn_id,
run_id=self.run_id,
wait_for_termination=self.wait_for_termination,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
check_interval=self.check_interval,
end_time=end_time,
),
method_name="execute_complete",
)
else:
if self.deferrable is True:
                warnings.warn(
                    "Argument `wait_for_termination` is False and `deferrable` is True, "
                    "hence the `deferrable` parameter has no effect.",
                )

def execute_complete(self, context: Context, event: dict[str, str]) -> None:
"""
        Callback for when the trigger fires; returns immediately.
        Relies on the trigger to raise an exception, otherwise it assumes execution was
        successful.
"""
if event:
if event["status"] == "error":
raise AirflowException(event["message"])
self.log.info(event["message"])

def on_kill(self) -> None:
if self.run_id:
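To make the new flag concrete, here is a minimal usage sketch (not part of this commit; the DAG id, connection id, and pipeline name are placeholders):

from __future__ import annotations

import pendulum

from airflow import DAG
from airflow.providers.microsoft.azure.operators.data_factory import (
    AzureDataFactoryRunPipelineOperator,
)

with DAG(
    dag_id="example_adf_deferrable",  # placeholder
    start_date=pendulum.datetime(2023, 3, 1, tz="UTC"),
    schedule=None,
    catchup=False,
):
    run_pipeline = AzureDataFactoryRunPipelineOperator(
        task_id="run_pipeline",
        azure_data_factory_conn_id="azure_data_factory_default",  # placeholder connection
        pipeline_name="my-pipeline",  # placeholder pipeline name
        parameters={"myParam": "value"},
        wait_for_termination=True,
        deferrable=True,  # poll on the triggerer instead of holding a worker slot
    )

With wait_for_termination=True and deferrable=True, execute submits the run, pushes run_id to XCom, then defers to AzureDataFactoryTrigger; the task resumes in execute_complete once the trigger fires.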
105 changes: 105 additions & 0 deletions airflow/providers/microsoft/azure/triggers/data_factory.py
@@ -17,6 +17,7 @@
from __future__ import annotations

import asyncio
import time
from typing import Any, AsyncIterator

from airflow.providers.microsoft.azure.hooks.data_factory import (
@@ -89,3 +90,107 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
await asyncio.sleep(self.poke_interval)
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e)})


class AzureDataFactoryTrigger(BaseTrigger):
"""
    AzureDataFactoryTrigger fires when an Azure Data Factory pipeline run succeeds or fails.
    When ``wait_for_termination`` is set to False, it fires immediately with a success status.

    :param run_id: The run ID of the Azure Data Factory pipeline run.
    :param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory.
    :param end_time: Epoch time (in seconds) at which the trigger times out.
:param resource_group_name: The resource group name.
:param factory_name: The data factory name.
:param wait_for_termination: Flag to wait on a pipeline run's termination.
:param check_interval: Time in seconds to check on a pipeline run's status.
"""

def __init__(
self,
run_id: str,
azure_data_factory_conn_id: str,
end_time: float,
resource_group_name: str | None = None,
factory_name: str | None = None,
wait_for_termination: bool = True,
check_interval: int = 60,
):
super().__init__()
self.azure_data_factory_conn_id = azure_data_factory_conn_id
self.check_interval = check_interval
self.run_id = run_id
self.wait_for_termination = wait_for_termination
self.resource_group_name = resource_group_name
self.factory_name = factory_name
self.end_time = end_time

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes AzureDataFactoryTrigger arguments and classpath."""
return (
"airflow.providers.microsoft.azure.triggers.data_factory.AzureDataFactoryTrigger",
{
"azure_data_factory_conn_id": self.azure_data_factory_conn_id,
"check_interval": self.check_interval,
"run_id": self.run_id,
"wait_for_termination": self.wait_for_termination,
"resource_group_name": self.resource_group_name,
"factory_name": self.factory_name,
"end_time": self.end_time,
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
"""Make async connection to Azure Data Factory, polls for the pipeline run status"""
hook = AzureDataFactoryAsyncHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
try:
pipeline_status = await hook.get_adf_pipeline_run_status(
run_id=self.run_id,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
)
if self.wait_for_termination:
while self.end_time > time.time():
pipeline_status = await hook.get_adf_pipeline_run_status(
run_id=self.run_id,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
)
if pipeline_status in AzureDataFactoryPipelineRunStatus.FAILURE_STATES:
yield TriggerEvent(
{
"status": "error",
"message": f"The pipeline run {self.run_id} has {pipeline_status}.",
"run_id": self.run_id,
}
)
elif pipeline_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED:
yield TriggerEvent(
{
"status": "success",
"message": f"The pipeline run {self.run_id} has {pipeline_status}.",
"run_id": self.run_id,
}
)
self.log.info(
"Sleeping for %s. The pipeline state is %s.", self.check_interval, pipeline_status
)
await asyncio.sleep(self.check_interval)

yield TriggerEvent(
{
"status": "error",
"message": f"Timeout: The pipeline run {self.run_id} has {pipeline_status}.",
"run_id": self.run_id,
}
)
else:
yield TriggerEvent(
{
"status": "success",
"message": f"The pipeline run {self.run_id} has {pipeline_status} status.",
"run_id": self.run_id,
}
)
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e), "run_id": self.run_id})
@@ -37,6 +37,15 @@ Below is an example of using this operator to execute an Azure Data Factory pipe
:start-after: [START howto_operator_adf_run_pipeline]
:end-before: [END howto_operator_adf_run_pipeline]

Below is an example of using this operator to execute an Azure Data Factory pipeline with the deferrable flag
so that polling for the status of the pipeline run occurs on the Airflow triggerer.

.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_adf_run_pipeline.py
:language: python
:dedent: 4
:start-after: [START howto_operator_adf_run_pipeline_with_deferrable_flag]
:end-before: [END howto_operator_adf_run_pipeline_with_deferrable_flag]

Here is a different example of using this operator to execute a pipeline, coupled with the :class:`~airflow.providers.microsoft.azure.sensors.data_factory.AzureDataFactoryPipelineRunStatusSensor` to perform an asynchronous wait.

.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_adf_run_pipeline.py
@@ -17,18 +17,26 @@
from __future__ import annotations

import json
from unittest import mock
from unittest.mock import MagicMock, patch

import pendulum
import pytest

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models import DAG, Connection
from airflow.models.baseoperator import BaseOperator
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.providers.microsoft.azure.hooks.data_factory import (
AzureDataFactoryHook,
AzureDataFactoryPipelineRunException,
AzureDataFactoryPipelineRunStatus,
)
from airflow.providers.microsoft.azure.operators.data_factory import AzureDataFactoryRunPipelineOperator
from airflow.providers.microsoft.azure.triggers.data_factory import AzureDataFactoryTrigger
from airflow.utils import db, timezone
from airflow.utils.types import DagRunType

DEFAULT_DATE = timezone.datetime(2021, 1, 1)
SUBSCRIPTION_ID = "my-subscription-id"
@@ -48,6 +56,7 @@
"resourceGroups/{resource_group_name}/providers/Microsoft.DataFactory/"
"factories/{factory_name}"
)
AZ_PIPELINE_RUN_ID = "7f8c6c72-c093-11ec-a83d-0242ac120007"


class TestAzureDataFactoryRunPipelineOperator:
@@ -252,3 +261,89 @@ def test_run_pipeline_operator_link(self, resource_group, factory, create_task_i
factory_name=factory if factory else conn_factory_name,
)
)


class TestAzureDataFactoryRunPipelineOperatorWithDeferrable:
OPERATOR = AzureDataFactoryRunPipelineOperator(
task_id="run_pipeline", pipeline_name="pipeline", parameters={"myParam": "value"}, deferrable=True
)

def get_dag_run(self, dag_id: str = "test_dag_id", run_id: str = "test_dag_id") -> DagRun:
dag_run = DagRun(
dag_id=dag_id, run_type="manual", execution_date=timezone.datetime(2022, 1, 1), run_id=run_id
)
return dag_run

def get_task_instance(self, task: BaseOperator) -> TaskInstance:
return TaskInstance(task, timezone.datetime(2022, 1, 1))

def get_conn(
self,
) -> Connection:
return Connection(
conn_id="test_conn",
extra={},
)

def create_context(self, task, dag=None):
if dag is None:
dag = DAG(dag_id="dag")
tzinfo = pendulum.timezone("UTC")
execution_date = timezone.datetime(2022, 1, 1, 1, 0, 0, tzinfo=tzinfo)
dag_run = DagRun(
dag_id=dag.dag_id,
execution_date=execution_date,
run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date),
)

task_instance = TaskInstance(task=task)
task_instance.dag_run = dag_run
task_instance.xcom_push = mock.Mock()
return {
"dag": dag,
"ts": execution_date.isoformat(),
"task": task,
"ti": task_instance,
"task_instance": task_instance,
"run_id": dag_run.run_id,
"dag_run": dag_run,
"execution_date": execution_date,
"data_interval_end": execution_date,
"logical_date": execution_date,
}

@mock.patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.run_pipeline")
def test_azure_data_factory_run_pipeline_operator_async(self, mock_run_pipeline):
"""Assert that AzureDataFactoryRunPipelineOperatorAsync deferred"""

class CreateRunResponse:
pass

CreateRunResponse.run_id = AZ_PIPELINE_RUN_ID
mock_run_pipeline.return_value = CreateRunResponse

with pytest.raises(TaskDeferred) as exc:
self.OPERATOR.execute(context=self.create_context(self.OPERATOR))

        assert isinstance(
            exc.value.trigger, AzureDataFactoryTrigger
        ), "Trigger is not an AzureDataFactoryTrigger"

def test_azure_data_factory_run_pipeline_operator_async_execute_complete_success(self):
"""Assert that execute_complete log success message"""

with mock.patch.object(self.OPERATOR.log, "info") as mock_log_info:
self.OPERATOR.execute_complete(
context={},
event={"status": "success", "message": "success", "run_id": AZ_PIPELINE_RUN_ID},
)
mock_log_info.assert_called_with("success")

def test_azure_data_factory_run_pipeline_operator_async_execute_complete_fail(self):
"""Assert that execute_complete raise exception on error"""

with pytest.raises(AirflowException):
self.OPERATOR.execute_complete(
context={},
event={"status": "error", "message": "error", "run_id": AZ_PIPELINE_RUN_ID},
)
