diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py b/airflow/providers/dbt/cloud/hooks/dbt.py
index 4b6ac2151a2f3..3ddeeb222b9b4 100644
--- a/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -22,8 +22,11 @@
 from enum import Enum
 from functools import wraps
 from inspect import signature
-from typing import Any, Callable, Sequence, Set
+from typing import Any, Callable, Sequence, Set, TypeVar, cast

+import aiohttp
+from aiohttp import ClientResponseError
+from asgiref.sync import sync_to_async
 from requests import PreparedRequest, Session
 from requests.auth import AuthBase
 from requests.models import Response
@@ -125,6 +128,34 @@ class DbtCloudJobRunException(AirflowException):
     """An exception that indicates a job run failed to complete."""


+T = TypeVar("T", bound=Any)
+
+
+def provide_account_id(func: T) -> T:
+    """
+    Decorator which provides a fallback value for ``account_id``. If the ``account_id`` is None or not
+    passed to the decorated function, the value will be taken from the configured dbt Cloud Airflow
+    Connection.
+    """
+    function_signature = signature(func)
+
+    @wraps(func)
+    async def wrapper(*args: Any, **kwargs: Any) -> Any:
+        bound_args = function_signature.bind(*args, **kwargs)
+
+        if bound_args.arguments.get("account_id") is None:
+            self = args[0]
+            if self.dbt_cloud_conn_id:
+                connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id)
+                default_account_id = connection.login
+                if not default_account_id:
+                    raise AirflowException("Could not determine the dbt Cloud account.")
+                bound_args.arguments["account_id"] = int(default_account_id)
+
+        return await func(*bound_args.args, **bound_args.kwargs)
+
+    return cast(T, wrapper)
+
+
 class DbtCloudHook(HttpHook):
     """
     Interact with dbt Cloud using the V2 API.
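A minimal sketch of the pattern the decorator enables (the subclass and method below are hypothetical;
``get_job_details`` further down in this file is the real consumer):

    from __future__ import annotations

    from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, provide_account_id

    class MyDbtCloudHook(DbtCloudHook):  # hypothetical subclass
        @provide_account_id
        async def get_account_id(self, account_id: int | None = None) -> int:
            # By the time the body runs, account_id has been filled in from the
            # connection's login field whenever the caller omitted it.
            assert account_id is not None
            return account_id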
+ """ + data: dict[str, Any] = {} + base_url = f"https://{tenant}.getdbt.com/api/v2/accounts/" + if include_related: + data = {"include_related": include_related} + if base_url and not base_url.endswith("/") and endpoint and not endpoint.startswith("/"): + url = base_url + "/" + endpoint + else: + url = (base_url or "") + (endpoint or "") + return url, data + + async def get_headers_tenants_from_connection(self) -> tuple[dict[str, Any], str]: + """Get Headers, tenants from the connection details""" + headers: dict[str, Any] = {} + connection: Connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id) + tenant: str = connection.schema if connection.schema else "cloud" + package_name, provider_version = _get_provider_info() + headers["User-Agent"] = f"{package_name}-v{provider_version}" + headers["Content-Type"] = "application/json" + headers["Authorization"] = f"Token {connection.password}" + return headers, tenant + + @provide_account_id + async def get_job_details( + self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None + ) -> Any: + """ + Uses Http async call to retrieve metadata for a specific run of a dbt Cloud job. + + :param run_id: The ID of a dbt Cloud job run. + :param account_id: Optional. The ID of a dbt Cloud account. + :param include_related: Optional. List of related fields to pull with the run. + Valid values are "trigger", "job", "repository", and "environment". + """ + endpoint = f"{account_id}/runs/{run_id}/" + headers, tenant = await self.get_headers_tenants_from_connection() + url, params = self.get_request_url_params(tenant, endpoint, include_related) + async with aiohttp.ClientSession(headers=headers) as session: + async with session.get(url, params=params) as response: + try: + response.raise_for_status() + return await response.json() + except ClientResponseError as e: + raise AirflowException(str(e.status) + ":" + e.message) + + async def get_job_status( + self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None + ) -> int: + """ + Retrieves the status for a specific run of a dbt Cloud job. + + :param run_id: The ID of a dbt Cloud job run. + :param account_id: Optional. The ID of a dbt Cloud account. + :param include_related: Optional. List of related fields to pull with the run. + Valid values are "trigger", "job", "repository", and "environment". 
+ """ + try: + self.log.info("Getting the status of job run %s.", str(run_id)) + response = await self.get_job_details( + run_id, account_id=account_id, include_related=include_related + ) + job_run_status: int = response["data"]["status"] + return job_run_status + except Exception as e: + raise e + @cached_property def connection(self) -> Connection: _connection = self.get_connection(self.dbt_cloud_conn_id) diff --git a/airflow/providers/dbt/cloud/operators/dbt.py b/airflow/providers/dbt/cloud/operators/dbt.py index 472b2ffa7f09d..f65ce077d3422 100644 --- a/airflow/providers/dbt/cloud/operators/dbt.py +++ b/airflow/providers/dbt/cloud/operators/dbt.py @@ -17,10 +17,14 @@ from __future__ import annotations import json +import time +import warnings from typing import TYPE_CHECKING, Any +from airflow.exceptions import AirflowException from airflow.models import BaseOperator, BaseOperatorLink, XCom from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus +from airflow.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -63,6 +67,7 @@ class DbtCloudRunJobOperator(BaseOperator): Used only if ``wait_for_termination`` is True. Defaults to 60 seconds. :param additional_run_config: Optional. Any additional parameters that should be included in the API request when triggering the job. + :param deferrable: Run operator in the deferrable mode :return: The ID of the triggered dbt Cloud job run. """ @@ -91,6 +96,7 @@ def __init__( timeout: int = 60 * 60 * 24 * 7, check_interval: int = 60, additional_run_config: dict[str, Any] | None = None, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -106,8 +112,9 @@ def __init__( self.additional_run_config = additional_run_config or {} self.hook: DbtCloudHook self.run_id: int + self.deferrable = deferrable - def execute(self, context: Context) -> int: + def execute(self, context: Context): if self.trigger_reason is None: self.trigger_reason = ( f"Triggered via Apache Airflow by task {self.task_id!r} in the {self.dag.dag_id} DAG." 
diff --git a/airflow/providers/dbt/cloud/operators/dbt.py b/airflow/providers/dbt/cloud/operators/dbt.py
index 472b2ffa7f09d..f65ce077d3422 100644
--- a/airflow/providers/dbt/cloud/operators/dbt.py
+++ b/airflow/providers/dbt/cloud/operators/dbt.py
@@ -17,10 +17,14 @@
 from __future__ import annotations

 import json
+import time
+import warnings
 from typing import TYPE_CHECKING, Any

+from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator, BaseOperatorLink, XCom
 from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus
+from airflow.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger

 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -63,6 +67,7 @@ class DbtCloudRunJobOperator(BaseOperator):
         Used only if ``wait_for_termination`` is True. Defaults to 60 seconds.
     :param additional_run_config: Optional. Any additional parameters that should be included in the API
         request when triggering the job.
+    :param deferrable: Run operator in deferrable mode.
     :return: The ID of the triggered dbt Cloud job run.
     """

@@ -91,6 +96,7 @@ def __init__(
         timeout: int = 60 * 60 * 24 * 7,
         check_interval: int = 60,
         additional_run_config: dict[str, Any] | None = None,
+        deferrable: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -106,8 +112,9 @@ def __init__(
         self.additional_run_config = additional_run_config or {}
         self.hook: DbtCloudHook
         self.run_id: int
+        self.deferrable = deferrable

-    def execute(self, context: Context) -> int:
+    def execute(self, context: Context):
         if self.trigger_reason is None:
             self.trigger_reason = (
                 f"Triggered via Apache Airflow by task {self.task_id!r} in the {self.dag.dag_id} DAG."
@@ -129,20 +136,52 @@ def execute(self, context: Context) -> int:
         context["ti"].xcom_push(key="job_run_url", value=job_run_url)

         if self.wait_for_termination:
-            self.log.info("Waiting for job run %s to terminate.", str(self.run_id))
-
-            if self.hook.wait_for_job_run_status(
-                run_id=self.run_id,
-                account_id=self.account_id,
-                expected_statuses=DbtCloudJobRunStatus.SUCCESS.value,
-                check_interval=self.check_interval,
-                timeout=self.timeout,
-            ):
-                self.log.info("Job run %s has completed successfully.", str(self.run_id))
+            if self.deferrable is False:
+                self.log.info("Waiting for job run %s to terminate.", str(self.run_id))
+
+                if self.hook.wait_for_job_run_status(
+                    run_id=self.run_id,
+                    account_id=self.account_id,
+                    expected_statuses=DbtCloudJobRunStatus.SUCCESS.value,
+                    check_interval=self.check_interval,
+                    timeout=self.timeout,
+                ):
+                    self.log.info("Job run %s has completed successfully.", str(self.run_id))
+                else:
+                    raise DbtCloudJobRunException(f"Job run {self.run_id} has failed or has been cancelled.")
+
+                return self.run_id
             else:
-                raise DbtCloudJobRunException(f"Job run {self.run_id} has failed or has been cancelled.")
-
-        return self.run_id
+                end_time = time.time() + self.timeout
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=DbtCloudRunJobTrigger(
+                        conn_id=self.dbt_cloud_conn_id,
+                        run_id=self.run_id,
+                        end_time=end_time,
+                        account_id=self.account_id,
+                        poll_interval=self.check_interval,
+                    ),
+                    method_name="execute_complete",
+                )
+        else:
+            if self.deferrable is True:
+                warnings.warn(
+                    "Argument `wait_for_termination` is False and `deferrable` is True, hence the "
+                    "`deferrable` parameter has no effect.",
+                )
+            return self.run_id
+
+    def execute_complete(self, context: "Context", event: dict[str, Any]) -> int:
+        """
+        Callback for when the trigger fires; returns immediately.
+        Relies on the trigger to raise an exception, otherwise it assumes execution was
+        successful.
+        """
+        if event["status"] == "error":
+            raise AirflowException(event["message"])
+        self.log.info(event["message"])
+        return int(event["run_id"])

     def on_kill(self) -> None:
         if self.run_id:
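With that in place, a deferrable task definition might look like the sketch below (the job ID,
connection ID, and DAG settings are illustrative placeholders):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator

    with DAG(dag_id="dbt_cloud_deferrable_example", start_date=datetime(2023, 1, 1), schedule=None):
        run_dbt_job = DbtCloudRunJobOperator(
            task_id="run_dbt_job",
            dbt_cloud_conn_id="dbt_cloud_default",
            job_id=1234,  # placeholder dbt Cloud job ID
            wait_for_termination=True,
            deferrable=True,  # poll from the triggerer instead of occupying a worker slot
            check_interval=30,
            timeout=60 * 60,
        )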
diff --git a/airflow/providers/dbt/cloud/provider.yaml b/airflow/providers/dbt/cloud/provider.yaml
index ad2817eb8e2eb..4315f9c272529 100644
--- a/airflow/providers/dbt/cloud/provider.yaml
+++ b/airflow/providers/dbt/cloud/provider.yaml
@@ -34,6 +34,8 @@ versions:
 dependencies:
   - apache-airflow>=2.3.0
   - apache-airflow-providers-http
+  - asgiref
+  - aiohttp

 integrations:
   - integration-name: dbt Cloud
diff --git a/airflow/providers/dbt/cloud/triggers/__init__.py b/airflow/providers/dbt/cloud/triggers/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/airflow/providers/dbt/cloud/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/dbt/cloud/triggers/dbt.py b/airflow/providers/dbt/cloud/triggers/dbt.py
new file mode 100644
index 0000000000000..9bad789a5246f
--- /dev/null
+++ b/airflow/providers/dbt/cloud/triggers/dbt.py
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import time
+from typing import Any, AsyncIterator
+
+from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunStatus
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class DbtCloudRunJobTrigger(BaseTrigger):
+    """
+    Trigger that makes async HTTP calls to dbt Cloud to poll for the status of a submitted job run
+    until it reaches a terminal state or the ``end_time`` deadline passes.
+
+    :param conn_id: The connection identifier for connecting to dbt Cloud.
+    :param run_id: The ID of a dbt Cloud job run.
+    :param end_time: The epoch time at which polling stops and the trigger reports a timeout.
+        Derived from the operator's ``timeout``, which defaults to 7 days.
+    :param account_id: The ID of a dbt Cloud account.
+    :param poll_interval: Polling period, in seconds, to check the job run status.
+    """
+ """ + + def __init__( + self, + conn_id: str, + run_id: int, + end_time: float, + poll_interval: float, + account_id: int | None, + ): + super().__init__() + self.run_id = run_id + self.account_id = account_id + self.conn_id = conn_id + self.end_time = end_time + self.poll_interval = poll_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes DbtCloudRunJobTrigger arguments and classpath.""" + return ( + "airflow.providers.dbt.cloud.triggers.dbt.DbtCloudRunJobTrigger", + { + "run_id": self.run_id, + "account_id": self.account_id, + "conn_id": self.conn_id, + "end_time": self.end_time, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator["TriggerEvent"]: + """Make async connection to Dbt, polls for the pipeline run status""" + hook = DbtCloudHook(self.conn_id) + try: + while await self.is_still_running(hook): + if self.end_time < time.time(): + yield TriggerEvent( + { + "status": "error", + "message": f"Job run {self.run_id} has not reached a terminal status after " + f"{self.end_time} seconds.", + "run_id": self.run_id, + } + ) + await asyncio.sleep(self.poll_interval) + job_run_status = await hook.get_job_status(self.run_id, self.account_id) + if job_run_status == DbtCloudJobRunStatus.SUCCESS.value: + yield TriggerEvent( + { + "status": "success", + "message": f"Job run {self.run_id} has completed successfully.", + "run_id": self.run_id, + } + ) + elif job_run_status == DbtCloudJobRunStatus.CANCELLED.value: + yield TriggerEvent( + { + "status": "cancelled", + "message": f"Job run {self.run_id} has been cancelled.", + "run_id": self.run_id, + } + ) + else: + yield TriggerEvent( + { + "status": "error", + "message": f"Job run {self.run_id} has failed.", + "run_id": self.run_id, + } + ) + except Exception as e: + yield TriggerEvent({"status": "error", "message": str(e), "run_id": self.run_id}) + + async def is_still_running(self, hook: DbtCloudHook) -> bool: + """ + Async function to check whether the job is submitted via async API is in + running state and returns True if it is still running else + return False + """ + job_run_status = await hook.get_job_status(self.run_id, self.account_id) + if not DbtCloudJobRunStatus.is_terminal(job_run_status): + return True + return False diff --git a/docs/apache-airflow-providers-dbt-cloud/operators.rst b/docs/apache-airflow-providers-dbt-cloud/operators.rst index de5b0b8060293..1f7b27b2808d8 100644 --- a/docs/apache-airflow-providers-dbt-cloud/operators.rst +++ b/docs/apache-airflow-providers-dbt-cloud/operators.rst @@ -40,6 +40,18 @@ execution time. This functionality is controlled by the ``wait_for_termination`` :class:`~airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor`). Setting ``wait_for_termination`` to False is a good approach for long-running dbt Cloud jobs. +The ``deferrable`` parameter along with ``wait_for_termination`` will control the functionality +whether to poll the job status on the worker or defer using the Triggerer. +When ``wait_for_termination`` is True and ``deferrable`` is False,we submit the job and ``poll`` +for its status on the worker. This will keep the worker slot occupied till the job execution is done. +When ``wait_for_termination`` is True and ``deferrable`` is True, +we submit the job and ``defer`` using Triggerer. This will release the worker slot leading to savings in +resource utilization while the job is running. 
diff --git a/docs/apache-airflow-providers-dbt-cloud/operators.rst b/docs/apache-airflow-providers-dbt-cloud/operators.rst
index de5b0b8060293..1f7b27b2808d8 100644
--- a/docs/apache-airflow-providers-dbt-cloud/operators.rst
+++ b/docs/apache-airflow-providers-dbt-cloud/operators.rst
@@ -40,6 +40,18 @@
 execution time. This functionality is controlled by the ``wait_for_termination``
 :class:`~airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor`). Setting
 ``wait_for_termination`` to False is a good approach for long-running dbt Cloud jobs.
+The ``deferrable`` parameter, together with ``wait_for_termination``, controls whether the job status
+is polled on the worker or the task is deferred to the Triggerer.
+When ``wait_for_termination`` is True and ``deferrable`` is False, we submit the job and ``poll``
+for its status on the worker. This keeps the worker slot occupied until the job execution is done.
+When ``wait_for_termination`` is True and ``deferrable`` is True, we submit the job and ``defer``
+using the Triggerer. This releases the worker slot while the job is running, leading to savings in
+resource utilization.
+
+When ``wait_for_termination`` is False and ``deferrable`` is False, we just submit the job and can only
+track the job status with the :class:`~airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor`.
+When ``wait_for_termination`` is False and ``deferrable`` is True, the ``deferrable`` parameter has no
+effect and a warning is emitted.
+
 While ``schema_override`` and ``steps_override`` are explicit, optional parameters for the
 ``DbtCloudRunJobOperator``, custom run configurations can also be passed to the operator using the
 ``additional_run_config`` dictionary. This parameter can be used to initialize additional runtime
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 90f878796bb95..8f42cded79975 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -241,8 +241,10 @@
   },
   "dbt.cloud": {
     "deps": [
+      "aiohttp",
       "apache-airflow-providers-http",
-      "apache-airflow>=2.3.0"
+      "apache-airflow>=2.3.0",
+      "asgiref"
     ],
     "cross-providers-deps": [
       "http"
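For the fire-and-forget combination described in the docs above, the operator pairs naturally with the
run sensor; a minimal sketch (job ID and task IDs are placeholders):

    from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator
    from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor

    trigger_job = DbtCloudRunJobOperator(
        task_id="trigger_dbt_job",
        job_id=1234,  # placeholder dbt Cloud job ID
        wait_for_termination=False,  # submit and return immediately
    )

    wait_for_job = DbtCloudJobRunSensor(
        task_id="wait_for_dbt_job",
        run_id=trigger_job.output,  # the run ID the operator pushes to XCom
    )

    trigger_job >> wait_for_job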