Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use entire tenant domain name in dbt Cloud connection #855

Merged
merged 1 commit into from
Jan 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions astronomer/providers/dbt/cloud/hooks/dbt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from functools import wraps
from inspect import signature
from typing import Any, Dict, List, Optional, Tuple, TypeVar, cast
Expand Down Expand Up @@ -58,7 +59,19 @@ async def get_headers_tenants_from_connection(self) -> Tuple[Dict[str, Any], str
"""Get Headers, tenants from the connection details"""
headers: Dict[str, Any] = {}
connection: Connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id)
tenant: str = connection.schema if connection.schema else "cloud"
if connection.schema:
warnings.warn(
"The `schema` parameter is deprecated and use within a dbt Cloud connection will be removed "
"in a future version. Please use `host` instead and specify the entire tenant domain name.",
category=DeprecationWarning,
stacklevel=2,
)
# Prior to deprecation, the connection.schema value could _only_ modify the third-level
# domain value while '.getdbt.com' was always used as the remainder of the domain name.
tenant = f"{connection.schema}.getdbt.com"
else:
tenant = connection.host or "cloud.getdbt.com"

provider_info = get_provider_info()
package_name = provider_info["package-name"]
version = provider_info["versions"]
Expand All @@ -74,19 +87,16 @@ def get_request_url_params(
"""
Form URL from base url and endpoint url

:param tenant: The tenant name which is need to be replaced in base url.
:param tenant: The tenant domain name which is need to be replaced in base url.
:param endpoint: Endpoint url to be requested.
:param include_related: Optional. List of related fields to pull with the run.
Valid values are "trigger", "job", "repository", and "environment".
"""
data: Dict[str, Any] = {}
base_url = f"https://{tenant}.getdbt.com/api/v2/accounts/"
base_url = f"https://{tenant}/api/v2/accounts/"
if include_related:
data = {"include_related": include_related}
if base_url and not base_url.endswith("/") and endpoint and not endpoint.startswith("/"):
url = base_url + "/" + endpoint
else:
url = (base_url or "") + (endpoint or "")
Copy link
Contributor Author

@josh-fell josh-fell Jan 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the if/else since this line would never be executed.

url = base_url + (endpoint or "")
return url, data

@provide_account_id
Expand Down
42 changes: 38 additions & 4 deletions tests/dbt/cloud/hooks/test_dbt_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ async def test_get_job_details_with_error(self, mock_get, mock_get_request_url_p
@pytest.mark.parametrize(
"mock_endpoint, mock_param, expected_url, expected_param",
[
("1234/run/1234", None, "https://localhost.getdbt.com/api/v2/accounts/1234/run/1234", {}),
("1234/run/1234", None, "https://localhost/api/v2/accounts/1234/run/1234", {}),
(
"1234/run/1234",
["test"],
"https://localhost.getdbt.com/api/v2/accounts/1234/run/1234",
"https://localhost/api/v2/accounts/1234/run/1234",
{"include_related": ["test"]},
),
],
Expand All @@ -158,7 +158,7 @@ def test_get_request_url_params(self, mock_endpoint, mock_param, expected_url, e

@pytest.mark.asyncio
@mock.patch("astronomer.providers.dbt.cloud.hooks.dbt.DbtCloudHookAsync.get_connection")
async def test_get_headers_tenants_from_connection(self, mock_get_connection):
async def test_get_headers_tenants_from_connection_host(self, mock_get_connection):
"""
Test get_headers_tenants_from_connection function to assert the
headers response with mocked connection details"""
Expand All @@ -167,7 +167,7 @@ async def test_get_headers_tenants_from_connection(self, mock_get_connection):
conn_type="test",
login=1234,
password="newT0k3n",
schema="Tenant",
host="Tenant",
extra=json.dumps(
{
"login": "test",
Expand All @@ -184,3 +184,37 @@ async def test_get_headers_tenants_from_connection(self, mock_get_connection):
headers, tenant = await hook.get_headers_tenants_from_connection()
assert headers == HEADERS
assert tenant == "Tenant"

@pytest.mark.asyncio
@mock.patch("astronomer.providers.dbt.cloud.hooks.dbt.DbtCloudHookAsync.get_connection")
async def test_get_headers_tenants_from_connection_schema(self, mock_get_connection):
"""
Test get_headers_tenants_from_connection function to assert the
headers response with mocked connection details.

TODO: This test can be removed once using `schema` from a dbt Cloud connection has been removed from
the OSS dbt Cloud provider.
"""
mock_get_connection.return_value = Connection(
conn_id=self.CONN_ID,
conn_type="test",
login=1234,
password="newT0k3n",
schema="Tenant",
extra=json.dumps(
{
"login": "test",
"password": "newT0k3n",
"schema": "Tenant",
}
),
)
provider_info = get_provider_info()
package_name = provider_info["package-name"]
version = provider_info["versions"]
HEADERS["User-Agent"] = f"{package_name}-v{version}"
hook = DbtCloudHookAsync(dbt_cloud_conn_id=self.CONN_ID)
with pytest.warns(DeprecationWarning):
headers, tenant = await hook.get_headers_tenants_from_connection()
assert headers == HEADERS
assert tenant == "Tenant.getdbt.com"