Add deferrable mode to DbtCloudRunJobOperator #29014

Merged (4 commits) on Jan 23, 2023

Changes from all commits
110 changes: 109 additions & 1 deletion airflow/providers/dbt/cloud/hooks/dbt.py
@@ -22,8 +22,11 @@
from enum import Enum
from functools import wraps
from inspect import signature
-from typing import Any, Callable, Sequence, Set
+from typing import Any, Callable, Sequence, Set, TypeVar, cast

import aiohttp
from aiohttp import ClientResponseError
from asgiref.sync import sync_to_async
from requests import PreparedRequest, Session
from requests.auth import AuthBase
from requests.models import Response
@@ -125,6 +128,34 @@ class DbtCloudJobRunException(AirflowException):
"""An exception that indicates a job run failed to complete."""


T = TypeVar("T", bound=Any)


def provide_account_id(func: T) -> T:
"""
Decorator which provides a fallback value for ``account_id``. If the ``account_id`` is None or not passed
to the decorated function, the value will be taken from the configured dbt Cloud Airflow Connection.
"""
function_signature = signature(func)
Comment on lines +134 to +139

Contributor: The problem with this wrapper (as with other similar decorators that already exist in community providers) is that every time we access a method decorated by this function, we need to look up the connection in the Secrets Backend (if configured), Environment Variables, and the Airflow Database.

Is there any way to make it cacheable? Something like this: #28716

Contributor (Author): will check if it can be made cacheable


@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
bound_args = function_signature.bind(*args, **kwargs)

if bound_args.arguments.get("account_id") is None:
self = args[0]
if self.dbt_cloud_conn_id:
connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id)
default_account_id = connection.login
if not default_account_id:
raise AirflowException("Could not determine the dbt Cloud account.")
bound_args.arguments["account_id"] = int(default_account_id)

return await func(*bound_args.args, **bound_args.kwargs)

return cast(T, wrapper)
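
Picking up the review thread above: a minimal sketch of how the connection lookup could be memoized so repeated decorated calls don't re-query the backends (hypothetical, not part of this PR; the name `cached_connection_lookup` is made up, and #28716 discusses a fuller approach):

```python
# Hypothetical sketch, not part of this PR: memoize connection lookups so that
# each conn_id is resolved against the secrets backend / environment variables /
# Airflow database only once per process.
from functools import wraps
from typing import Any, Callable


def cached_connection_lookup(get_connection: Callable[[str], Any]) -> Callable[[str], Any]:
    cache: dict[str, Any] = {}

    @wraps(get_connection)
    def wrapper(conn_id: str) -> Any:
        if conn_id not in cache:
            cache[conn_id] = get_connection(conn_id)  # only the first call hits the backends
        return cache[conn_id]

    return wrapper
```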


class DbtCloudHook(HttpHook):
"""
Interact with dbt Cloud using the V2 API.
@@ -150,6 +181,83 @@ def __init__(self, dbt_cloud_conn_id: str = default_conn_name, *args, **kwargs)
super().__init__(auth_type=TokenAuth)
self.dbt_cloud_conn_id = dbt_cloud_conn_id

@staticmethod
def get_request_url_params(
tenant: str, endpoint: str, include_related: list[str] | None = None
) -> tuple[str, dict[str, Any]]:
"""
Form the full request URL from the base URL and the endpoint.

:param tenant: The tenant name to substitute into the base URL.
:param endpoint: Endpoint url to be requested.
:param include_related: Optional. List of related fields to pull with the run.
Valid values are "trigger", "job", "repository", and "environment".
"""
data: dict[str, Any] = {}
base_url = f"https://{tenant}.getdbt.com/api/v2/accounts/"
if include_related:
data = {"include_related": include_related}
if base_url and not base_url.endswith("/") and endpoint and not endpoint.startswith("/"):
url = base_url + "/" + endpoint
else:
url = (base_url or "") + (endpoint or "")
return url, data
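
For illustration, a hedged usage sketch of the static helper above (the tenant, account, and run IDs are made up):

```python
# Hypothetical usage sketch for get_request_url_params; IDs are made up.
from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook

url, params = DbtCloudHook.get_request_url_params(
    tenant="cloud", endpoint="42/runs/99/", include_related=["job"]
)
assert url == "https://cloud.getdbt.com/api/v2/accounts/42/runs/99/"
assert params == {"include_related": ["job"]}
```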

async def get_headers_tenants_from_connection(self) -> tuple[dict[str, Any], str]:
"""Get Headers, tenants from the connection details"""
headers: dict[str, Any] = {}
connection: Connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id)
tenant: str = connection.schema if connection.schema else "cloud"
package_name, provider_version = _get_provider_info()
headers["User-Agent"] = f"{package_name}-v{provider_version}"
headers["Content-Type"] = "application/json"
headers["Authorization"] = f"Token {connection.password}"
return headers, tenant
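
To make the field mapping concrete, a hedged example of the connection shape this hook reads (all values are placeholders; `conn_type="dbt_cloud"` is assumed to be the provider's registered connection type):

```python
# Hedged example: the dbt Cloud connection fields used by the async hook methods.
# All values are placeholders.
from airflow.models import Connection

conn = Connection(
    conn_id="dbt_cloud_default",
    conn_type="dbt_cloud",
    login="123456",            # fallback account ID consumed by provide_account_id
    password="dbt-api-token",  # sent as "Token <password>" in the Authorization header
    schema="cloud",            # tenant; yields https://cloud.getdbt.com/api/v2/accounts/
)
```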

@provide_account_id
async def get_job_details(
self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None
) -> Any:
"""
Uses an async HTTP call to retrieve metadata for a specific run of a dbt Cloud job.

:param run_id: The ID of a dbt Cloud job run.
:param account_id: Optional. The ID of a dbt Cloud account.
:param include_related: Optional. List of related fields to pull with the run.
Valid values are "trigger", "job", "repository", and "environment".
"""
endpoint = f"{account_id}/runs/{run_id}/"
headers, tenant = await self.get_headers_tenants_from_connection()
url, params = self.get_request_url_params(tenant, endpoint, include_related)
async with aiohttp.ClientSession(headers=headers) as session:
async with session.get(url, params=params) as response:
try:
response.raise_for_status()
return await response.json()
except ClientResponseError as e:
raise AirflowException(str(e.status) + ":" + e.message)

async def get_job_status(
self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None
) -> int:
"""
Retrieves the status for a specific run of a dbt Cloud job.

:param run_id: The ID of a dbt Cloud job run.
:param account_id: Optional. The ID of a dbt Cloud account.
:param include_related: Optional. List of related fields to pull with the run.
Valid values are "trigger", "job", "repository", and "environment".
"""
try:
self.log.info("Getting the status of job run %s.", str(run_id))
response = await self.get_job_details(
run_id, account_id=account_id, include_related=include_related
)
job_run_status: int = response["data"]["status"]
return job_run_status
except Exception as e:
raise e
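
As a usage illustration, a small sketch of calling the async methods above from a coroutine (the connection ID and run ID are hypothetical):

```python
# Hypothetical sketch: fetch a run's status from an async context.
import asyncio

from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunStatus


async def run_is_done(run_id: int) -> bool:
    hook = DbtCloudHook("dbt_cloud_default")
    status = await hook.get_job_status(run_id)  # account_id falls back to the connection login
    return DbtCloudJobRunStatus.is_terminal(status)


# asyncio.run(run_is_done(99))
```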

@cached_property
def connection(self) -> Connection:
_connection = self.get_connection(self.dbt_cloud_conn_id)
67 changes: 53 additions & 14 deletions airflow/providers/dbt/cloud/operators/dbt.py
@@ -17,10 +17,14 @@
from __future__ import annotations

import json
import time
import warnings
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, BaseOperatorLink, XCom
from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus
from airflow.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger

if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -63,6 +67,7 @@ class DbtCloudRunJobOperator(BaseOperator):
Used only if ``wait_for_termination`` is True. Defaults to 60 seconds.
:param additional_run_config: Optional. Any additional parameters that should be included in the API
request when triggering the job.
:param deferrable: Run the operator in deferrable mode.
:return: The ID of the triggered dbt Cloud job run.
"""

@@ -91,6 +96,7 @@ def __init__(
timeout: int = 60 * 60 * 24 * 7,
check_interval: int = 60,
additional_run_config: dict[str, Any] | None = None,
deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -106,8 +112,9 @@ def __init__(
self.additional_run_config = additional_run_config or {}
self.hook: DbtCloudHook
self.run_id: int
self.deferrable = deferrable

-def execute(self, context: Context) -> int:
+def execute(self, context: Context):
if self.trigger_reason is None:
self.trigger_reason = (
f"Triggered via Apache Airflow by task {self.task_id!r} in the {self.dag.dag_id} DAG."
@@ -129,20 +136,52 @@ def execute(self, context: Context) -> int:
context["ti"].xcom_push(key="job_run_url", value=job_run_url)

if self.wait_for_termination:
self.log.info("Waiting for job run %s to terminate.", str(self.run_id))

if self.hook.wait_for_job_run_status(
run_id=self.run_id,
account_id=self.account_id,
expected_statuses=DbtCloudJobRunStatus.SUCCESS.value,
check_interval=self.check_interval,
timeout=self.timeout,
):
self.log.info("Job run %s has completed successfully.", str(self.run_id))
if self.deferrable is False:
self.log.info("Waiting for job run %s to terminate.", str(self.run_id))

if self.hook.wait_for_job_run_status(
run_id=self.run_id,
account_id=self.account_id,
expected_statuses=DbtCloudJobRunStatus.SUCCESS.value,
check_interval=self.check_interval,
timeout=self.timeout,
):
self.log.info("Job run %s has completed successfully.", str(self.run_id))
else:
raise DbtCloudJobRunException(f"Job run {self.run_id} has failed or has been cancelled.")

return self.run_id
else:
raise DbtCloudJobRunException(f"Job run {self.run_id} has failed or has been cancelled.")

return self.run_id
end_time = time.time() + self.timeout
self.defer(
timeout=self.execution_timeout,
trigger=DbtCloudRunJobTrigger(
conn_id=self.dbt_cloud_conn_id,
run_id=self.run_id,
end_time=end_time,
account_id=self.account_id,
poll_interval=self.check_interval,
),
method_name="execute_complete",
)
else:
if self.deferrable is True:
warnings.warn(
"Argument `wait_for_termination` is False and `deferrable` is True , hence "
"`deferrable` parameter doesn't have any effect",
)
return self.run_id

def execute_complete(self, context: "Context", event: dict[str, Any]) -> int:
"""
Callback for when the trigger fires; returns immediately.
Relies on the trigger to throw an exception; otherwise it assumes execution was
successful.
"""
if event["status"] == "error":
raise AirflowException(event["message"])
self.log.info(event["message"])
return int(event["run_id"])
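
To show how the new flag is used, a hedged DAG sketch (the job ID and connection ID are placeholders; per the warning branch above, `deferrable=True` only takes effect when `wait_for_termination` is also True):

```python
# Hedged usage sketch: run a dbt Cloud job in deferrable mode. IDs are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator

with DAG("example_dbt_cloud_deferrable", start_date=datetime(2023, 1, 1), schedule_interval=None) as dag:
    run_job = DbtCloudRunJobOperator(
        task_id="run_dbt_job",
        dbt_cloud_conn_id="dbt_cloud_default",
        job_id=42,                  # hypothetical dbt Cloud job ID
        wait_for_termination=True,  # required for deferral to kick in
        deferrable=True,            # frees the worker slot; the triggerer polls instead
        check_interval=60,
        timeout=60 * 60,
    )
```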

def on_kill(self) -> None:
if self.run_id:
2 changes: 2 additions & 0 deletions airflow/providers/dbt/cloud/provider.yaml
@@ -34,6 +34,8 @@ versions:
dependencies:
- apache-airflow>=2.3.0
- apache-airflow-providers-http
- asgiref
- aiohttp

integrations:
- integration-name: dbt Cloud
16 changes: 16 additions & 0 deletions airflow/providers/dbt/cloud/triggers/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
119 changes: 119 additions & 0 deletions airflow/providers/dbt/cloud/triggers/dbt.py
@@ -0,0 +1,119 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
import time
from typing import Any, AsyncIterator

from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunStatus
from airflow.triggers.base import BaseTrigger, TriggerEvent


class DbtCloudRunJobTrigger(BaseTrigger):
"""
DbtCloudRunJobTrigger is initialized with a run ID and account ID; it makes async HTTP calls to
dbt Cloud and polls for the status of the submitted job run at the given interval.

:param conn_id: The connection identifier for connecting to dbt Cloud.
:param run_id: The ID of a dbt Cloud job run.
:param end_time: Epoch timestamp after which waiting for the job run to reach a terminal status
    times out. The operator computes it as ``time.time() + timeout``.
:param account_id: The ID of a dbt Cloud account.
:param poll_interval: Polling period in seconds for checking the run status.
"""

def __init__(
self,
conn_id: str,
run_id: int,
end_time: float,
poll_interval: float,
account_id: int | None,
):
super().__init__()
self.run_id = run_id
self.account_id = account_id
self.conn_id = conn_id
self.end_time = end_time
self.poll_interval = poll_interval

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes DbtCloudRunJobTrigger arguments and classpath."""
return (
"airflow.providers.dbt.cloud.triggers.dbt.DbtCloudRunJobTrigger",
{
"run_id": self.run_id,
"account_id": self.account_id,
"conn_id": self.conn_id,
"end_time": self.end_time,
"poll_interval": self.poll_interval,
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
"""Make async connection to Dbt, polls for the pipeline run status"""
hook = DbtCloudHook(self.conn_id)
try:
while await self.is_still_running(hook):
if self.end_time < time.time():
yield TriggerEvent(
{
"status": "error",
"message": f"Job run {self.run_id} has not reached a terminal status after "
f"{self.end_time} seconds.",
"run_id": self.run_id,
}
)
await asyncio.sleep(self.poll_interval)
job_run_status = await hook.get_job_status(self.run_id, self.account_id)
if job_run_status == DbtCloudJobRunStatus.SUCCESS.value:
yield TriggerEvent(
{
"status": "success",
"message": f"Job run {self.run_id} has completed successfully.",
"run_id": self.run_id,
}
)
elif job_run_status == DbtCloudJobRunStatus.CANCELLED.value:
yield TriggerEvent(
{
"status": "cancelled",
"message": f"Job run {self.run_id} has been cancelled.",
"run_id": self.run_id,
}
)
else:
yield TriggerEvent(
{
"status": "error",
"message": f"Job run {self.run_id} has failed.",
"run_id": self.run_id,
}
)
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e), "run_id": self.run_id})

async def is_still_running(self, hook: DbtCloudHook) -> bool:
"""
Async function to check whether the job run submitted via the async API is still in a
non-terminal state. Returns True if it is still running, otherwise False.
"""
job_run_status = await hook.get_job_status(self.run_id, self.account_id)
if not DbtCloudJobRunStatus.is_terminal(job_run_status):
return True
return False
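
Finally, a test-style sketch of consuming the trigger's event stream directly, outside a triggerer process (the IDs and the five-minute deadline are made up):

```python
# Hypothetical test-style sketch: drive DbtCloudRunJobTrigger by hand.
from __future__ import annotations

import asyncio
import time
from typing import Any

from airflow.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger


async def first_event() -> dict[str, Any] | None:
    trigger = DbtCloudRunJobTrigger(
        conn_id="dbt_cloud_default",
        run_id=99,                   # hypothetical run ID
        end_time=time.time() + 300,  # absolute deadline, as execute() computes it
        poll_interval=10,
        account_id=None,             # falls back to the connection's login
    )
    async for event in trigger.run():
        return event.payload         # e.g. {"status": "success", ...}
    return None


# asyncio.run(first_event())
```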