Allow downloading of dbt Cloud artifacts to non-existent paths (#29048)
Closes: #27107

Currently, the DbtCloudGetJobRunArtifactOperator does not support downloading dbt Cloud job run artifacts to non-existent local paths, so users cannot specify a dynamic file path to which to download the artifact.
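For illustration (the path below is hypothetical, not from this PR), this is the failure mode: writing with a plain open() does not create missing parent directories.

# Illustration only; the nested path is a hypothetical example.
# open() in write mode raises FileNotFoundError when "artifacts/manifests/" does not exist.
with open("artifacts/manifests/manifest.json", "w") as file:
    file.write("{}")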

This PR enhances the DbtCloudGetJobRunArtifactOperator to create any missing parent directories of the output file before writing the artifact.
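As a usage sketch (the connection ID, run ID, and output path below are illustrative placeholders, not values from this PR), the operator can now be pointed at a nested output location that does not exist yet:

from airflow.providers.dbt.cloud.operators.dbt import DbtCloudGetJobRunArtifactOperator

# Placeholder connection, run ID, and paths; the parent directories ("artifacts/manifests/" here)
# no longer need to exist up front. The operator creates them before writing the artifact and
# returns the path it wrote to, so downstream tasks can pull it from XCom.
get_manifest = DbtCloudGetJobRunArtifactOperator(
    task_id="get_manifest_artifact",
    dbt_cloud_conn_id="dbt_cloud_default",
    run_id=12345,
    path="manifest.json",
    output_file_name="artifacts/manifests/12345_manifest.json",
)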
josh-fell committed Jan 23, 2023
1 parent 6190e34 commit f805b41
Showing 2 changed files with 62 additions and 6 deletions.
12 changes: 10 additions & 2 deletions airflow/providers/dbt/cloud/operators/dbt.py
@@ -19,6 +19,7 @@
import json
import time
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException
@@ -239,18 +240,25 @@ def __init__(
        self.step = step
        self.output_file_name = output_file_name or f"{self.run_id}_{self.path}".replace("/", "-")

    def execute(self, context: Context) -> None:
    def execute(self, context: Context) -> str:
        hook = DbtCloudHook(self.dbt_cloud_conn_id)
        response = hook.get_job_run_artifact(
            run_id=self.run_id, path=self.path, account_id=self.account_id, step=self.step
        )

        with open(self.output_file_name, "w") as file:
        output_file_path = Path(self.output_file_name)
        output_file_path.parent.mkdir(parents=True, exist_ok=True)
        with output_file_path.open(mode="w") as file:
            self.log.info(
                "Writing %s artifact for job run %s to %s.", self.path, self.run_id, self.output_file_name
            )
            if self.path.endswith(".json"):
                json.dump(response.json(), file)
            else:
                file.write(response.text)

        return self.output_file_name


class DbtCloudListJobsOperator(BaseOperator):
    """
56 changes: 52 additions & 4 deletions tests/providers/dbt/cloud/operators/test_dbt_cloud.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import os
from unittest.mock import MagicMock, patch

import pytest
@@ -301,7 +302,7 @@ def test_get_json_artifact(self, mock_get_artifact, conn_id, account_id):
        )

        mock_get_artifact.return_value.json.return_value = {"data": "file contents"}
        operator.execute(context={})
        return_value = operator.execute(context={})

        mock_get_artifact.assert_called_once_with(
            run_id=RUN_ID,
@@ -310,6 +311,10 @@ def test_get_json_artifact(self, mock_get_artifact, conn_id, account_id):
            step=None,
        )

        assert operator.output_file_name == f"{RUN_ID}_path-to-my-manifest.json"
        assert os.path.exists(operator.output_file_name)
        assert return_value == operator.output_file_name

    @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_artifact")
    @pytest.mark.parametrize(
        "conn_id, account_id",
@@ -328,7 +333,7 @@ def test_get_json_artifact_with_step(self, mock_get_artifact, conn_id, account_id):
        )

        mock_get_artifact.return_value.json.return_value = {"data": "file contents"}
        operator.execute(context={})
        return_value = operator.execute(context={})

        mock_get_artifact.assert_called_once_with(
            run_id=RUN_ID,
@@ -337,6 +342,10 @@ def test_get_json_artifact_with_step(self, mock_get_artifact, conn_id, account_id):
            step=2,
        )

        assert operator.output_file_name == f"{RUN_ID}_path-to-my-manifest.json"
        assert os.path.exists(operator.output_file_name)
        assert return_value == operator.output_file_name

    @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_artifact")
    @pytest.mark.parametrize(
        "conn_id, account_id",
@@ -354,7 +363,7 @@ def test_get_text_artifact(self, mock_get_artifact, conn_id, account_id):
        )

        mock_get_artifact.return_value.text = "file contents"
        operator.execute(context={})
        return_value = operator.execute(context={})

        mock_get_artifact.assert_called_once_with(
            run_id=RUN_ID,
@@ -363,6 +372,10 @@ def test_get_text_artifact(self, mock_get_artifact, conn_id, account_id):
            step=None,
        )

        assert operator.output_file_name == f"{RUN_ID}_path-to-my-model.sql"
        assert os.path.exists(operator.output_file_name)
        assert return_value == operator.output_file_name

    @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_artifact")
    @pytest.mark.parametrize(
        "conn_id, account_id",
@@ -381,7 +394,7 @@ def test_get_text_artifact_with_step(self, mock_get_artifact, conn_id, account_id):
        )

        mock_get_artifact.return_value.text = "file contents"
        operator.execute(context={})
        return_value = operator.execute(context={})

        mock_get_artifact.assert_called_once_with(
            run_id=RUN_ID,
@@ -390,6 +403,41 @@ def test_get_text_artifact_with_step(self, mock_get_artifact, conn_id, account_id):
            step=2,
        )

        assert operator.output_file_name == f"{RUN_ID}_path-to-my-model.sql"
        assert os.path.exists(operator.output_file_name)
        assert return_value == operator.output_file_name

    @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_artifact")
    @pytest.mark.parametrize(
        "conn_id, account_id",
        [(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
        ids=["default_account", "explicit_account"],
    )
    def test_get_artifact_with_specified_output_file(self, mock_get_artifact, conn_id, account_id, tmp_path):
        operator = DbtCloudGetJobRunArtifactOperator(
            task_id=TASK_ID,
            dbt_cloud_conn_id=conn_id,
            run_id=RUN_ID,
            account_id=account_id,
            path="run_results.json",
            dag=self.dag,
            output_file_name=tmp_path / "run_results.json",
        )

        mock_get_artifact.return_value.json.return_value = {"data": "file contents"}
        return_value = operator.execute(context={})

        mock_get_artifact.assert_called_once_with(
            run_id=RUN_ID,
            path="run_results.json",
            account_id=account_id,
            step=None,
        )

        assert operator.output_file_name == tmp_path / "run_results.json"
        assert os.path.exists(operator.output_file_name)
        assert return_value == operator.output_file_name


class TestDbtCloudListJobsOperator:
    def setup_method(self):
