diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py b/airflow/providers/dbt/cloud/hooks/dbt.py index 85eba8da04f64..4f229320f060a 100644 --- a/airflow/providers/dbt/cloud/hooks/dbt.py +++ b/airflow/providers/dbt/cloud/hooks/dbt.py @@ -165,7 +165,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: class DbtCloudHook(HttpHook): """ - Interact with dbt Cloud using the V2 API. + Interact with dbt Cloud using the V2 (V3 if supported) API. :param dbt_cloud_conn_id: The ID of the :ref:`dbt Cloud connection `. """ @@ -194,7 +194,7 @@ def _get_tenant_domain(conn: Connection) -> str: @staticmethod def get_request_url_params( - tenant: str, endpoint: str, include_related: list[str] | None = None + tenant: str, endpoint: str, include_related: list[str] | None = None, *, api_version: str = "v2" ) -> tuple[str, dict[str, Any]]: """ Form URL from base url and endpoint url. @@ -207,7 +207,7 @@ def get_request_url_params( data: dict[str, Any] = {} if include_related: data = {"include_related": include_related} - url = f"https://{tenant}/api/v2/accounts/{endpoint or ''}" + url = f"https://{tenant}/api/{api_version}/accounts/{endpoint or ''}" return url, data async def get_headers_tenants_from_connection(self) -> tuple[dict[str, Any], str]: @@ -270,7 +270,7 @@ def connection(self) -> Connection: def get_conn(self, *args, **kwargs) -> Session: tenant = self._get_tenant_domain(self.connection) - self.base_url = f"https://{tenant}/api/v2/accounts/" + self.base_url = f"https://{tenant}/" session = Session() session.auth = self.auth_type(self.connection.password) @@ -302,19 +302,22 @@ def _run_and_get_response( endpoint: str | None = None, payload: str | dict[str, Any] | None = None, paginate: bool = False, + *, + api_version: str = "v2", ) -> Any: self.method = method + full_endpoint = f"api/{api_version}/accounts/{endpoint}" if endpoint else None if paginate: if isinstance(payload, str): raise ValueError("Payload cannot be a string to paginate a response.") if endpoint: - return self._paginate(endpoint=endpoint, payload=payload) - else: - raise ValueError("An endpoint is needed to paginate a response.") + return self._paginate(endpoint=full_endpoint, payload=payload) - return self.run(endpoint=endpoint, data=payload) + raise ValueError("An endpoint is needed to paginate a response.") + + return self.run(endpoint=full_endpoint, data=payload) def list_accounts(self) -> list[Response]: """ @@ -342,7 +345,7 @@ def list_projects(self, account_id: int | None = None) -> list[Response]: :param account_id: Optional. The ID of a dbt Cloud account. :return: List of request responses. """ - return self._run_and_get_response(endpoint=f"{account_id}/projects/", paginate=True) + return self._run_and_get_response(endpoint=f"{account_id}/projects/", paginate=True, api_version="v3") @fallback_to_default_account def get_project(self, project_id: int, account_id: int | None = None) -> Response: @@ -353,7 +356,7 @@ def get_project(self, project_id: int, account_id: int | None = None) -> Respons :param account_id: Optional. The ID of a dbt Cloud account. :return: The request response. """ - return self._run_and_get_response(endpoint=f"{account_id}/projects/{project_id}/") + return self._run_and_get_response(endpoint=f"{account_id}/projects/{project_id}/", api_version="v3") @fallback_to_default_account def list_jobs( diff --git a/tests/providers/dbt/cloud/hooks/test_dbt.py b/tests/providers/dbt/cloud/hooks/test_dbt.py index 39d31a444b0c6..9a65ba0acf05d 100644 --- a/tests/providers/dbt/cloud/hooks/test_dbt.py +++ b/tests/providers/dbt/cloud/hooks/test_dbt.py @@ -46,8 +46,8 @@ JOB_ID = 4444 RUN_ID = 5555 -BASE_URL = "https://cloud.getdbt.com/api/v2/accounts/" -SINGLE_TENANT_URL = "https://single.tenant.getdbt.com/api/v2/accounts/" +BASE_URL = "https://cloud.getdbt.com/" +SINGLE_TENANT_URL = "https://single.tenant.getdbt.com/" class TestDbtCloudJobRunStatus: @@ -211,7 +211,7 @@ def test_get_account(self, mock_http_run, mock_paginate, conn_id, account_id): assert hook.method == "GET" _account_id = account_id or DEFAULT_ACCOUNT_ID - hook.run.assert_called_once_with(endpoint=f"{_account_id}/", data=None) + hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/", data=None) hook._paginate.assert_not_called() @pytest.mark.parametrize( @@ -229,7 +229,9 @@ def test_list_projects(self, mock_http_run, mock_paginate, conn_id, account_id): _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_not_called() - hook._paginate.assert_called_once_with(endpoint=f"{_account_id}/projects/", payload=None) + hook._paginate.assert_called_once_with( + endpoint=f"api/v3/accounts/{_account_id}/projects/", payload=None + ) @pytest.mark.parametrize( argnames="conn_id, account_id", @@ -245,7 +247,9 @@ def test_get_project(self, mock_http_run, mock_paginate, conn_id, account_id): assert hook.method == "GET" _account_id = account_id or DEFAULT_ACCOUNT_ID - hook.run.assert_called_once_with(endpoint=f"{_account_id}/projects/{PROJECT_ID}/", data=None) + hook.run.assert_called_once_with( + endpoint=f"api/v3/accounts/{_account_id}/projects/{PROJECT_ID}/", data=None + ) hook._paginate.assert_not_called() @pytest.mark.parametrize( @@ -263,7 +267,7 @@ def test_list_jobs(self, mock_http_run, mock_paginate, conn_id, account_id): _account_id = account_id or DEFAULT_ACCOUNT_ID hook._paginate.assert_called_once_with( - endpoint=f"{_account_id}/jobs/", payload={"order_by": None, "project_id": None} + endpoint=f"api/v2/accounts/{_account_id}/jobs/", payload={"order_by": None, "project_id": None} ) hook.run.assert_not_called() @@ -282,7 +286,8 @@ def test_list_jobs_with_payload(self, mock_http_run, mock_paginate, conn_id, acc _account_id = account_id or DEFAULT_ACCOUNT_ID hook._paginate.assert_called_once_with( - endpoint=f"{_account_id}/jobs/", payload={"order_by": "-id", "project_id": PROJECT_ID} + endpoint=f"api/v2/accounts/{_account_id}/jobs/", + payload={"order_by": "-id", "project_id": PROJECT_ID}, ) hook.run.assert_not_called() @@ -300,7 +305,7 @@ def test_get_job(self, mock_http_run, mock_paginate, conn_id, account_id): assert hook.method == "GET" _account_id = account_id or DEFAULT_ACCOUNT_ID - hook.run.assert_called_once_with(endpoint=f"{_account_id}/jobs/{JOB_ID}", data=None) + hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}", data=None) hook._paginate.assert_not_called() @pytest.mark.parametrize( @@ -319,7 +324,7 @@ def test_trigger_job_run(self, mock_http_run, mock_paginate, conn_id, account_id _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_called_once_with( - endpoint=f"{_account_id}/jobs/{JOB_ID}/run/", + endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/", data=json.dumps({"cause": cause, "steps_override": None, "schema_override": None}), ) hook._paginate.assert_not_called() @@ -348,7 +353,7 @@ def test_trigger_job_run_with_overrides(self, mock_http_run, mock_paginate, conn _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_called_once_with( - endpoint=f"{_account_id}/jobs/{JOB_ID}/run/", + endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/", data=json.dumps( {"cause": cause, "steps_override": steps_override, "schema_override": schema_override} ), @@ -376,7 +381,7 @@ def test_trigger_job_run_with_additional_run_configs( _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_called_once_with( - endpoint=f"{_account_id}/jobs/{JOB_ID}/run/", + endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/", data=json.dumps( { "cause": cause, @@ -405,7 +410,7 @@ def test_list_job_runs(self, mock_http_run, mock_paginate, conn_id, account_id): _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_not_called() hook._paginate.assert_called_once_with( - endpoint=f"{_account_id}/runs/", + endpoint=f"api/v2/accounts/{_account_id}/runs/", payload={ "include_related": None, "job_definition_id": None, @@ -431,7 +436,7 @@ def test_list_job_runs_with_payload(self, mock_http_run, mock_paginate, conn_id, _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_not_called() hook._paginate.assert_called_once_with( - endpoint=f"{_account_id}/runs/", + endpoint=f"api/v2/accounts/{_account_id}/runs/", payload={ "include_related": ["job"], "job_definition_id": JOB_ID, @@ -452,7 +457,7 @@ def test_get_job_runs(self, mock_http_run, conn_id, account_id): assert hook.method == "GET" _account_id = account_id or DEFAULT_ACCOUNT_ID - hook.run.assert_called_once_with(endpoint=f"{_account_id}/runs/", data=None) + hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/runs/", data=None) @pytest.mark.parametrize( argnames="conn_id, account_id", @@ -469,7 +474,7 @@ def test_get_job_run(self, mock_http_run, mock_paginate, conn_id, account_id): _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_called_once_with( - endpoint=f"{_account_id}/runs/{RUN_ID}/", data={"include_related": None} + endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/", data={"include_related": None} ) hook._paginate.assert_not_called() @@ -488,7 +493,7 @@ def test_get_job_run_with_payload(self, mock_http_run, mock_paginate, conn_id, a _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_called_once_with( - endpoint=f"{_account_id}/runs/{RUN_ID}/", data={"include_related": ["triggers"]} + endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/", data={"include_related": ["triggers"]} ) hook._paginate.assert_not_called() @@ -543,7 +548,9 @@ def test_cancel_job_run(self, mock_http_run, mock_paginate, conn_id, account_id) assert hook.method == "POST" _account_id = account_id or DEFAULT_ACCOUNT_ID - hook.run.assert_called_once_with(endpoint=f"{_account_id}/runs/{RUN_ID}/cancel/", data=None) + hook.run.assert_called_once_with( + endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/cancel/", data=None + ) hook._paginate.assert_not_called() @pytest.mark.parametrize( @@ -561,7 +568,7 @@ def test_list_job_run_artifacts(self, mock_http_run, mock_paginate, conn_id, acc _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_called_once_with( - endpoint=f"{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": None} + endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": None} ) hook._paginate.assert_not_called() @@ -579,7 +586,9 @@ def test_list_job_run_artifacts_with_payload(self, mock_http_run, mock_paginate, assert hook.method == "GET" _account_id = account_id or DEFAULT_ACCOUNT_ID - hook.run.assert_called_once_with(endpoint=f"{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": 2}) + hook.run.assert_called_once_with( + endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": 2} + ) hook._paginate.assert_not_called() @pytest.mark.parametrize( @@ -598,7 +607,7 @@ def test_get_job_run_artifact(self, mock_http_run, mock_paginate, conn_id, accou _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_called_once_with( - endpoint=f"{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": None} + endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": None} ) hook._paginate.assert_not_called() @@ -618,7 +627,7 @@ def test_get_job_run_artifact_with_payload(self, mock_http_run, mock_paginate, c _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_called_once_with( - endpoint=f"{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": 2} + endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": 2} ) hook._paginate.assert_not_called()