diff --git a/astronomer/providers/dbt/cloud/hooks/dbt.py b/astronomer/providers/dbt/cloud/hooks/dbt.py index f70993cc1..57b4e8fd8 100644 --- a/astronomer/providers/dbt/cloud/hooks/dbt.py +++ b/astronomer/providers/dbt/cloud/hooks/dbt.py @@ -1,3 +1,4 @@ +import warnings from functools import wraps from inspect import signature from typing import Any, Dict, List, Optional, Tuple, TypeVar, cast @@ -58,7 +59,19 @@ async def get_headers_tenants_from_connection(self) -> Tuple[Dict[str, Any], str """Get Headers, tenants from the connection details""" headers: Dict[str, Any] = {} connection: Connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id) - tenant: str = connection.schema if connection.schema else "cloud" + if connection.schema: + warnings.warn( + "The `schema` parameter is deprecated and use within a dbt Cloud connection will be removed " + "in a future version. Please use `host` instead and specify the entire tenant domain name.", + category=DeprecationWarning, + stacklevel=2, + ) + # Prior to deprecation, the connection.schema value could _only_ modify the third-level + # domain value while '.getdbt.com' was always used as the remainder of the domain name. + tenant = f"{connection.schema}.getdbt.com" + else: + tenant = connection.host or "cloud.getdbt.com" + provider_info = get_provider_info() package_name = provider_info["package-name"] version = provider_info["versions"] @@ -74,19 +87,16 @@ def get_request_url_params( """ Form URL from base url and endpoint url - :param tenant: The tenant name which is need to be replaced in base url. + :param tenant: The tenant domain name which is need to be replaced in base url. :param endpoint: Endpoint url to be requested. :param include_related: Optional. List of related fields to pull with the run. Valid values are "trigger", "job", "repository", and "environment". """ data: Dict[str, Any] = {} - base_url = f"https://{tenant}.getdbt.com/api/v2/accounts/" + base_url = f"https://{tenant}/api/v2/accounts/" if include_related: data = {"include_related": include_related} - if base_url and not base_url.endswith("/") and endpoint and not endpoint.startswith("/"): - url = base_url + "/" + endpoint - else: - url = (base_url or "") + (endpoint or "") + url = base_url + (endpoint or "") return url, data @provide_account_id diff --git a/tests/dbt/cloud/hooks/test_dbt_hooks.py b/tests/dbt/cloud/hooks/test_dbt_hooks.py index 075ac9f99..566b0331b 100644 --- a/tests/dbt/cloud/hooks/test_dbt_hooks.py +++ b/tests/dbt/cloud/hooks/test_dbt_hooks.py @@ -140,11 +140,11 @@ async def test_get_job_details_with_error(self, mock_get, mock_get_request_url_p @pytest.mark.parametrize( "mock_endpoint, mock_param, expected_url, expected_param", [ - ("1234/run/1234", None, "https://localhost.getdbt.com/api/v2/accounts/1234/run/1234", {}), + ("1234/run/1234", None, "https://localhost/api/v2/accounts/1234/run/1234", {}), ( "1234/run/1234", ["test"], - "https://localhost.getdbt.com/api/v2/accounts/1234/run/1234", + "https://localhost/api/v2/accounts/1234/run/1234", {"include_related": ["test"]}, ), ], @@ -158,7 +158,7 @@ def test_get_request_url_params(self, mock_endpoint, mock_param, expected_url, e @pytest.mark.asyncio @mock.patch("astronomer.providers.dbt.cloud.hooks.dbt.DbtCloudHookAsync.get_connection") - async def test_get_headers_tenants_from_connection(self, mock_get_connection): + async def test_get_headers_tenants_from_connection_host(self, mock_get_connection): """ Test get_headers_tenants_from_connection function to assert the headers response with mocked connection details""" @@ -167,7 +167,7 @@ async def test_get_headers_tenants_from_connection(self, mock_get_connection): conn_type="test", login=1234, password="newT0k3n", - schema="Tenant", + host="Tenant", extra=json.dumps( { "login": "test", @@ -184,3 +184,37 @@ async def test_get_headers_tenants_from_connection(self, mock_get_connection): headers, tenant = await hook.get_headers_tenants_from_connection() assert headers == HEADERS assert tenant == "Tenant" + + @pytest.mark.asyncio + @mock.patch("astronomer.providers.dbt.cloud.hooks.dbt.DbtCloudHookAsync.get_connection") + async def test_get_headers_tenants_from_connection_schema(self, mock_get_connection): + """ + Test get_headers_tenants_from_connection function to assert the + headers response with mocked connection details. + + TODO: This test can be removed once using `schema` from a dbt Cloud connection has been removed from + the OSS dbt Cloud provider. + """ + mock_get_connection.return_value = Connection( + conn_id=self.CONN_ID, + conn_type="test", + login=1234, + password="newT0k3n", + schema="Tenant", + extra=json.dumps( + { + "login": "test", + "password": "newT0k3n", + "schema": "Tenant", + } + ), + ) + provider_info = get_provider_info() + package_name = provider_info["package-name"] + version = provider_info["versions"] + HEADERS["User-Agent"] = f"{package_name}-v{version}" + hook = DbtCloudHookAsync(dbt_cloud_conn_id=self.CONN_ID) + with pytest.warns(DeprecationWarning): + headers, tenant = await hook.get_headers_tenants_from_connection() + assert headers == HEADERS + assert tenant == "Tenant.getdbt.com"