Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use entire tenant domain name in dbt Cloud connection #28890

Merged
merged 1 commit into from
Jan 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions airflow/providers/dbt/cloud/hooks/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import json
import time
import warnings
from enum import Enum
from functools import wraps
from inspect import signature
Expand Down Expand Up @@ -140,17 +141,14 @@ class DbtCloudHook(HttpHook):
def get_ui_field_behaviour() -> dict[str, Any]:
"""Builds custom field behavior for the dbt Cloud connection form in the Airflow UI."""
return {
"hidden_fields": ["host", "port", "extra"],
"relabeling": {"login": "Account ID", "password": "API Token", "schema": "Tenant"},
"placeholders": {"schema": "Defaults to 'cloud'."},
"hidden_fields": ["schema", "port", "extra"],
"relabeling": {"login": "Account ID", "password": "API Token", "host": "Tenant"},
"placeholders": {"host": "Defaults to 'cloud.getdbt.com'."},
}

def __init__(self, dbt_cloud_conn_id: str = default_conn_name, *args, **kwargs) -> None:
super().__init__(auth_type=TokenAuth)
self.dbt_cloud_conn_id = dbt_cloud_conn_id
tenant = self.connection.schema if self.connection.schema else "cloud"

self.base_url = f"https://{tenant}.getdbt.com/api/v2/accounts/"

@cached_property
def connection(self) -> Connection:
Expand All @@ -161,6 +159,21 @@ def connection(self) -> Connection:
return _connection

def get_conn(self, *args, **kwargs) -> Session:
if self.connection.schema:
warnings.warn(
"The `schema` parameter is deprecated and use within a dbt Cloud connection will be removed "
"in a future version. Please use `host` instead and specify the entire tenant domain name.",
category=DeprecationWarning,
stacklevel=2,
)
# Prior to deprecation, the connection.schema value could _only_ modify the third-level
# domain value while '.getdbt.com' was always used as the remainder of the domain name.
tenant = f"{self.connection.schema}.getdbt.com"
else:
tenant = self.connection.host or "cloud.getdbt.com"

self.base_url = f"https://{tenant}/api/v2/accounts/"

session = Session()
session.auth = self.auth_type(self.connection.password)

Expand Down
19 changes: 13 additions & 6 deletions docs/apache-airflow-providers-dbt-cloud/connections.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
specific language governing permissions and limitations
under the License.

.. spelling::

getdbt


.. _howto/connection:dbt-cloud:
Expand Down Expand Up @@ -65,10 +68,14 @@ Login (optional)
If an Account ID is not provided in an Airflow connection, ``account_id`` *must* be explicitly passed to
an operator or hook method.

Schema (optional)
The Tenant name for your dbt Cloud environment (i.e. https://my-tenant.getdbt.com). This is particularly
useful when using a single-tenant dbt Cloud instance. If a Tenant name is not provided, "cloud"
will be used as the default value (i.e. https://cloud.getdbt.com) assuming a multi-tenant instance.
Host (optional)
The Tenant domain for your dbt Cloud environment (e.g. "my-tenant.getdbt.com"). This is particularly
useful when using a single-tenant dbt Cloud instance, or an instance in one of the `other dbt Cloud regions <https://docs.getdbt.com/docs/deploy/regions-ip-addresses>`__
such as EMEA or a Virtual Private dbt Cloud. If a Tenant domain is not provided, "cloud.getdbt.com" will be
used as the default value, which assumes a multi-tenant instance in the North America region.

If using the Connection form in the Airflow UI, the Tenant domain can also be stored in the "Tenant"
field.

When specifying the connection as an environment variable, you should specify it following the standard syntax
of a database connection. Note that all components of the URI should be URL-encoded.
Expand All @@ -88,11 +95,11 @@ For example, to add a connection with the connection ID of "dbt_cloud_default":

export AIRFLOW_CONN_DBT_CLOUD_DEFAULT='dbt-cloud://:api_token@'

When specifying Tenant name:
When specifying a Tenant domain:

.. code-block:: bash

export AIRFLOW_CONN_DBT_CLOUD_DEFAULT='dbt-cloud://:api_token@:/my-tenant'
export AIRFLOW_CONN_DBT_CLOUD_DEFAULT='dbt-cloud://:api_token@my-tenant.getdbt.com'

You can refer to the documentation on
:ref:`creating connections via environment variables <environment_variables_secrets_backend>` for more
Expand Down
19 changes: 14 additions & 5 deletions tests/providers/dbt/cloud/hooks/test_dbt_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
SINGLE_TENANT_CONN = "single_tenant_conn"
DEFAULT_ACCOUNT_ID = 11111
ACCOUNT_ID = 22222
SINGLE_TENANT_SCHEMA = "single.tenant"
SINGLE_TENANT_DOMAIN = "single.tenant.getdbt.com"
TOKEN = "token"
PROJECT_ID = 33333
JOB_ID = 4444
Expand Down Expand Up @@ -123,18 +123,18 @@ def setup_class(self):
password=TOKEN,
)

# Connection with `schema` parameter set
schema_conn = Connection(
# Connection with `host` parameter set
host_conn = Connection(
conn_id=SINGLE_TENANT_CONN,
conn_type=DbtCloudHook.conn_type,
login=DEFAULT_ACCOUNT_ID,
password=TOKEN,
schema=SINGLE_TENANT_SCHEMA,
host=SINGLE_TENANT_DOMAIN,
)

db.merge_conn(account_id_conn)
db.merge_conn(no_account_id_conn)
db.merge_conn(schema_conn)
db.merge_conn(host_conn)

@pytest.mark.parametrize(
argnames="conn_id, url",
Expand All @@ -146,6 +146,15 @@ def test_init_hook(self, conn_id, url):
assert hook.auth_type == TokenAuth
assert hook.method == "POST"
assert hook.dbt_cloud_conn_id == conn_id

@pytest.mark.parametrize(
argnames="conn_id, url",
argvalues=[(ACCOUNT_ID_CONN, BASE_URL), (SINGLE_TENANT_CONN, SINGLE_TENANT_URL)],
ids=["multi-tenant", "single-tenant"],
)
def test_tenant_base_url(self, conn_id, url):
hook = DbtCloudHook(conn_id)
hook.get_conn()
assert hook.base_url == url

@pytest.mark.parametrize(
Expand Down